diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 9308a4157..173a06a0c 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,6 +1,6 @@ # See here for image contents: https://github.com/microsoft/vscode-dev-containers/tree/v0.192.0/containers/python-3/.devcontainer/base.Dockerfile -# [Choice] Python version: 3, 3.9, 3.8, 3.7, 3.6 +# [Choice] Python version: 3, 3.9, 3.8, 3.7 ARG VARIANT="3.9" FROM mcr.microsoft.com/vscode/devcontainers/python:0-${VARIANT} diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 812f68cc4..545834d34 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -6,7 +6,7 @@ "dockerfile": "Dockerfile", "context": "..", "args": { - // Update 'VARIANT' to pick a Python version: 3, 3.6, 3.7, 3.8, 3.9 + // Update 'VARIANT' to pick a Python version: 3, 3.7, 3.8, 3.9 "VARIANT": "3", // Options "NODE_VERSION": "lts/*" diff --git a/.github/ISSUE_TEMPLATE/bug-report.md b/.github/ISSUE_TEMPLATE/bug-report.md index b1c412062..2d72e49de 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.md +++ b/.github/ISSUE_TEMPLATE/bug-report.md @@ -27,6 +27,7 @@ Please follow the instructions and template below to save us time requesting add - A detailed description. - A [Minimal Complete Reproducible Example](https://stackoverflow.com/help/mcve). This is code we can cut and paste into a readily available sample and run, or a link to a project you've written that we can compile to reproduce the bug. - Console logs. + - If this is a connection related issue, include logs from the [Connection Diagnostic Tool](https://github.com/Azure/azure-iot-connection-diagnostic-tool) 5. Delete these instructions before submitting the bug. 
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8c2296054..28350c17c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,8 +4,14 @@ repos: hooks: - id: black language_version: python3 -- repo: https://gitlab.com/pycqa/flake8 +- repo: https://github.com/pycqa/flake8 rev: 3.9.1 # Use the ref you want to point at hooks: - id: flake8 - args: ['--config=.flake8'] \ No newline at end of file + args: ['--config=.flake8'] +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.0.0 + hooks: + - id: mypy + files: azure-iot-device + exclude: tests diff --git a/README.md b/README.md index 5aedbe2da..4ea6932d6 100644 --- a/README.md +++ b/README.md @@ -8,30 +8,34 @@ The Azure IoT Device SDK for Python enables Python developers to easily create IoT device solutions that seamlessly connect to the Azure IoT Hub ecosystem. -* *If you're looking for the azure-iot-hub library, it is now located in the [azure-iot-hub-python](https://github.com/Azure/azure-iot-hub-python) repository* +* *If you're migrating v2.x.x code to use v3.x.x check the [IoT Hub Migration Guide](https://github.com/Azure/azure-iot-sdk-python/blob/main/migration_guide_iothub.md) and/or the [Provisioning Migration Guide](https://github.com/Azure/azure-iot-sdk-python/blob/main/migration_guide_provisioning.md)* + +* *If you're looking for the v2.x.x client library, it is now preserved in the [v2](https://github.com/Azure/azure-iot-sdk-python/tree/v2) branch* * *If you're looking for the v1.x.x client library, it is now preserved in the [v1-deprecated](https://github.com/Azure/azure-iot-sdk-python/tree/v1-deprecated) branch.* +* *If you're looking for the azure-iot-hub library, it is now located in the [azure-iot-hub-python](https://github.com/Azure/azure-iot-hub-python) repository* + +**NOTE: 3.x.x is still in beta and APIs are subject to change until the 
release of 3.0.0** + ## Installing the library The Azure IoT Device library is available on PyPI: ```Shell -pip install azure-iot-device +pip install azure-iot-device==3.0.0b2 ``` -Python 3.6 or higher is required in order to use the library +Python 3.7 or higher is required in order to use the library -## Using the library -API documentation for this package is available via [**Microsoft Docs**](https://docs.microsoft.com/python/api/azure-iot-device/azure.iot.device?view=azure-python). - -See our [**quickstart guide**](https://github.com/Azure/azure-iot-sdk-python/tree/main/samples/README.md) for step by step instructions for setting up and using an IoTHub with devices. -You can also view the [**samples repository**](https://github.com/Azure/azure-iot-sdk-python/tree/main/samples) to see additional examples of basic client usage. +## Using the library +You can view the [**samples directory**](https://github.com/Azure/azure-iot-sdk-python/tree/main/samples) to see examples of SDK usage. -Want to start off on the right foot? Be sure to learn about [**common pitfalls**](https://github.com/Azure/azure-iot-sdk-python/wiki/pitfalls) of using this Python SDK before starting a project. +Full API documentation for this package is available via [**Microsoft Docs**](https://docs.microsoft.com/python/api/azure-iot-device/azure.iot.device?view=azure-python). Note that this documentation may currently be out of date as v3.x.x is still in preview at the time of this writing. +You can use the [**Connection Diagnostic Tool**](https://github.com/Azure/azure-iot-connection-diagnostic-tool) to help ascertain the cause of any connection issues you run into when using the SDK. ## Features @@ -39,9 +43,9 @@ Want to start off on the right foot? 
Be sure to learn about [**common pitfalls** *Features that are not planned may be prioritized in a future release, but are not currently planned -These clients only support the **MQTT protocol**. +This library primarily uses the **MQTT protocol**. -### IoTHub Device Client +### IoTHubSession | Features | Status | Description | |------------------------------------------------------------------------------------------------------------------|----------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| @@ -50,51 +54,31 @@ These clients only support the **MQTT protocol**. | [Receive cloud-to-device messages](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-messages-c2d) | :heavy_check_mark: | Receive cloud-to-device messages and read associated custom and system properties from IoT Hub. | | [Device Twins](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-device-twins) | :heavy_check_mark: | IoT Hub persists a device twin for each device that you connect to IoT Hub. The device can perform operations like get twin tags, subscribe to desired properties. | | [Direct Methods](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-direct-methods) | :heavy_check_mark: | IoT Hub gives you the ability to invoke direct methods on devices from the cloud. The SDK supports handler for method specific and generic operation. | -| [Connection Status and Error reporting](https://docs.microsoft.com/en-us/rest/api/iothub/common-error-codes) | :heavy_check_mark: | Error reporting for IoT Hub supported error code. | -| Connection Retry | :heavy_check_mark: | Dropped connections will be retried with a fixed 10 second interval by default. 
This functionality can be disabled if desired, and the interval can be configured | -| [Upload file to Blob](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-file-upload) | :heavy_check_mark: | A device can initiate a file upload and notifies IoT Hub when the upload is complete. | - -### IoTHub Module Client - -**Note:** IoT Edge for Python is scoped to Linux containers & devices only. [Learn more](https://techcommunity.microsoft.com/t5/internet-of-things/linux-modules-with-azure-iot-edge-on-windows-10-iot-enterprise/ba-p/1407066) about using Linux containers for IoT edge on Windows devices. - -| Features | Status | Description | -|------------------------------------------------------------------------------------------------------------------|----------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [Authentication](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-security-deployment) | :heavy_check_mark: | Connect your device to IoT Hub securely with supported authentication, including symmetric key, X-509 Self Signed, and Certificate Authority (CA) Signed. SASToken authentication is not currently supported. | -| [Send device-to-cloud message](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-messages-d2c) | :heavy_check_mark: | Send device-to-cloud messages (max 256KB) to IoT Hub with the option to add custom properties. | -| [Receive cloud-to-device messages](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-messages-c2d) | :heavy_check_mark: | Receive cloud-to-device messages and read associated custom and system properties from IoT Hub, with the option to complete/reject/abandon C2D messages. 
| -| [Device Twins](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-device-twins) | :heavy_check_mark: | IoT Hub persists a device twin for each device that you connect to IoT Hub. The device can perform operations like get twin tags, subscribe to desired properties. | -| [Direct Methods](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-direct-methods) | :heavy_check_mark: | IoT Hub gives you the ability to invoke direct methods on devices from the cloud. The SDK supports handler for method specific and generic operation. | -| [Connection Status and Error reporting](https://docs.microsoft.com/en-us/rest/api/iothub/common-error-codes) | :heavy_check_mark: | Error reporting for IoT Hub supported error code. | -| Connection Retry | :heavy_check_mark: | Dropped connections will be retried with a fixed 10 second interval. TThis functionality can be disabled if desired, and the interval can be configured | -| Direct Invocation of Method on Modules | :heavy_check_mark: | Invoke method calls to another module using using the Edge Gateway. | - -### Provisioning Device Client +| [Connection Status and Error reporting](https://docs.microsoft.com/en-us/rest/api/iothub/common-error-codes) | :heavy_check_mark: | Error reporting for IoT Hub supported error code. | +| Connection Retry | :heavy_check_mark: | Dropped connections will be retried with a fixed 10 second interval by default. This functionality can be disabled if desired, and the interval can be configured | +| [Upload file to Blob](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-file-upload) | :heavy_multiplication_x: | A device can initiate a file upload and notifies IoT Hub when the upload is complete. | +| Direct Invocation of Method on Modules | :heavy_multiplication_x: | Invoke method calls to another module using using the Edge Gateway. 
-| Features | Status | Description | -|-----------------------------|--------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| TPM Individual Enrollment | :heavy_minus_sign: | Provisioning via [Trusted Platform Module](https://docs.microsoft.com/en-us/azure/iot-dps/concepts-security#trusted-platform-module-tpm). | -| X.509 Individual Enrollment | :heavy_check_mark: | Provisioning via [X.509 root certificate](https://docs.microsoft.com/en-us/azure/iot-dps/concepts-security#root-certificate). Please review the [samples](azure-iot-device/samples/async-hub-scenarios/provision_x509.py) folder and this [quickstart](https://docs.microsoft.com/en-us/azure/iot-dps/quick-create-simulated-device-x509-python) on how to create a device client. | -| X.509 Enrollment Group | :heavy_check_mark: | Provisioning via [X.509 leaf certificate](https://docs.microsoft.com/en-us/azure/iot-dps/concepts-security#leaf-certificate)). Please review the [samples](azure-iot-device/samples/async-hub-scenarios/provision_x509.py) folder on how to create a device client. | -| Symmetric Key Enrollment | :heavy_check_mark: | Provisioning via [Symmetric key attestation](https://docs.microsoft.com/en-us/azure/iot-dps/concepts-symmetric-key-attestation)). Please review the [samples](azure-iot-device/samples/async-hub-scenarios/provision_symmetric_key.py) folder on how to create a device client. | -## Critical Upcoming Changes Notice +### ProvisioningSession -### Certificates -All Azure IoT SDK users are advised to be aware of upcoming TLS certificate changes for Azure IoT Hub and Device Provisioning Service -that will impact the SDK's ability to connect to these services. 
In October 2022, both services will migrate from the current -[Baltimore CyberTrust CA Root](https://baltimore-cybertrust-root.chain-demos.digicert.com/info/index.html) to the -[DigiCert Global G2 CA root](https://global-root-g2.chain-demos.digicert.com/info/index.html). There will be a -transition period beforehand where your IoT devices must have both the Baltimore and Digicert public certificates -installed in their certificate store in order to prevent connectivity issues. +| Features | Status | Description | +|-----------------------------|--------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| TPM Individual Enrollment | :heavy_multiplication_x: | Provisioning via [Trusted Platform Module](https://docs.microsoft.com/en-us/azure/iot-dps/concepts-security#trusted-platform-module-tpm). | +| X.509 Individual Enrollment | :heavy_check_mark: | Provisioning via [X.509 root certificate](https://docs.microsoft.com/en-us/azure/iot-dps/concepts-security#root-certificate). Please review the [samples](azure-iot-device/samples/async-hub-scenarios/provision_x509.py) folder and this [quickstart](https://docs.microsoft.com/en-us/azure/iot-dps/quick-create-simulated-device-x509-python) on how to create a device client. | +| X.509 Enrollment Group | :heavy_check_mark: | Provisioning via [X.509 leaf certificate](https://docs.microsoft.com/en-us/azure/iot-dps/concepts-security#leaf-certificate). Please review the [samples](azure-iot-device/samples/async-hub-scenarios/provision_x509.py) folder on how to create a device client. 
| +| Symmetric Key Enrollment | :heavy_check_mark: | Provisioning via [Symmetric key attestation](https://docs.microsoft.com/en-us/azure/iot-dps/concepts-symmetric-key-attestation). Please review the [samples](azure-iot-device/samples/async-hub-scenarios/provision_symmetric_key.py) folder on how to create a device client. | -**Devices with only the Baltimore public certificate installed will lose the ability to connect to Azure IoT hub and Device Provisioning Service in October 2022.** +## Support -To prepare for this change, make sure your device's certificate store has both of these public certificates installed. +The Azure IoT Hub Device Client supported releases is outlined in the following table. -For a more in depth explanation as to why the IoT services are doing this, please see -[this article](https://techcommunity.microsoft.com/t5/internet-of-things/azure-iot-tls-critical-changes-are-almost-here-and-why-you/ba-p/2393169). +Refer to the [Azure IoT Device SDK lifecycle and support](https://learn.microsoft.com/en-us/azure/iot/iot-device-sdks-lifecycle-and-support) for details on the different supported stages. +| Release | Category | End-of-life | +|-|-|-| +| 2.12.0 | Active | - | +| 3.0.0b2 | Preview | 2023-7-15 | ## Contributing diff --git a/azure-iot-device/azure/iot/device/__init__.py b/azure-iot-device/azure/iot/device/__init__.py index 6476c1eee..79770c5b2 100644 --- a/azure-iot-device/azure/iot/device/__init__.py +++ b/azure-iot-device/azure/iot/device/__init__.py @@ -4,42 +4,20 @@ from an IoT device. """ -# Import all exposed items in subpackages to expose them via this package -from .iothub import * # noqa: F401, F403 -from .provisioning import * # noqa: F401, F403 -from .common import * # noqa: F401, F403 TODO: do we really want to do this? - -# Import the subpackages themselves in order to set the __all__ -from . import iothub -from . import provisioning -from . import common - -# Import the module to generate missing documentation -from . 
import patch_documentation - - -# TODO: remove this chunk of commented code if we truly no longer want to take this approach - -# Dynamically patch the clients to add shim implementations for all the inherited methods. -# This is necessary to generate accurate online docs. -# It SHOULD not impact the functionality of the methods themselves in any way. - -# NOTE In the event of addition of new methods and generation of accurate documentation -# for those methods we have to append content to "patch_documentation.py" file. -# In order to do so please uncomment the "patch.add_shims" lines below, -# enable logging with level "DEBUG" in a python terminal and do -# "import azure.iot.device". The delta between the newly generated output -# and the existing content of "patch_documentation.py" should be appended to -# the function "execute_patch_for_sync" in "patch_documentation.py". -# Once done please again comment out the "patch.add_shims" lines below. - -# from . import patch -# patch.add_shims_for_inherited_methods(IoTHubDeviceClient) # noqa: F405 -# patch.add_shims_for_inherited_methods(IoTHubModuleClient) # noqa: F405 -# patch.add_shims_for_inherited_methods(ProvisioningDeviceClient) # noqa: F405 -patch_documentation.execute_patch_for_sync() - - -# iothub and common subpackages are still showing up in intellisense - -__all__ = iothub.__all__ + provisioning.__all__ + common.__all__ +from .iothub_session import IoTHubSession # noqa: F401 +from .provisioning_session import ProvisioningSession # noqa: F401 +from .exceptions import ( # noqa: F401 + IoTHubError, + IoTEdgeError, + IoTEdgeEnvironmentError, + ProvisioningServiceError, + SessionError, + IoTHubClientError, + MQTTError, + MQTTConnectionFailedError, + MQTTConnectionDroppedError, +) + +# TODO: directly here, or via the models module? +from .models import Message, DirectMethodRequest, DirectMethodResponse # noqa: F401 +from . 
import models # noqa: F401 diff --git a/azure-iot-device/azure/iot/device/aio/__init__.py b/azure-iot-device/azure/iot/device/aio/__init__.py deleted file mode 100644 index 582aff88b..000000000 --- a/azure-iot-device/azure/iot/device/aio/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Azure IoT Device Library - Asynchronous - -This library provides asynchronous clients for communicating with Azure IoT services -from an IoT device. -""" - -# Import all exposed items in aio subpackages to expose them via this package -from azure.iot.device.iothub.aio import * # noqa: F401, F403 -from azure.iot.device.provisioning.aio import * # noqa: F401, F403 - -# Import the subpackages themselves in order to set the __all__ -import azure.iot.device.iothub.aio -import azure.iot.device.provisioning.aio - -# Import the module to generate missing documentation -from . import patch_documentation - - -# TODO: remove this chunk of commented code if we truly no longer want to take this approach - -# Dynamically patch the clients to add shim implementations for all the inherited methods. -# This is necessary to generate accurate online docs. -# It SHOULD not impact the functionality of the methods themselves in any way. - -# NOTE In the event of addition of new methods and generation of accurate documentation -# for those methods we have to append content to "patch_documentation.py" file. -# In order to do so please uncomment the "patch.add_shims" lines below, -# enable logging with level "DEBUG" in a python terminal and do -# "import azure.iot.device". The delta between the newly generated output -# and the existing content of "patch_documentation.py" should be appended to -# the function "execute_patch_for_sync" in "patch_documentation.py". -# Once done please again comment out the "patch.add_shims" lines below. 
- -# from azure.iot.device import patch -# patch.add_shims_for_inherited_methods(IoTHubDeviceClient) # noqa: F405 -# patch.add_shims_for_inherited_methods(IoTHubModuleClient) # noqa: F405 -# patch.add_shims_for_inherited_methods(ProvisioningDeviceClient) # noqa: F405 - - -patch_documentation.execute_patch_for_async() - -__all__ = azure.iot.device.iothub.aio.__all__ + azure.iot.device.provisioning.aio.__all__ diff --git a/azure-iot-device/azure/iot/device/aio/patch_documentation.py b/azure-iot-device/azure/iot/device/aio/patch_documentation.py deleted file mode 100644 index 9c629270d..000000000 --- a/azure-iot-device/azure/iot/device/aio/patch_documentation.py +++ /dev/null @@ -1,314 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module provides hard coded patches used to modify items from the libraries. 
-Currently we have to do like this so that we don't use exec anywhere""" - - -def execute_patch_for_async(): - from azure.iot.device.iothub.aio.async_clients import IoTHubDeviceClient as IoTHubDeviceClient - - async def connect(self): - return await super(IoTHubDeviceClient, self).connect() - - connect.__doc__ = IoTHubDeviceClient.connect.__doc__ - setattr(IoTHubDeviceClient, "connect", connect) - - async def disconnect(self): - return await super(IoTHubDeviceClient, self).disconnect() - - disconnect.__doc__ = IoTHubDeviceClient.disconnect.__doc__ - setattr(IoTHubDeviceClient, "disconnect", disconnect) - - async def get_twin(self): - return await super(IoTHubDeviceClient, self).get_twin() - - get_twin.__doc__ = IoTHubDeviceClient.get_twin.__doc__ - setattr(IoTHubDeviceClient, "get_twin", get_twin) - - async def patch_twin_reported_properties(self, reported_properties_patch): - return await super(IoTHubDeviceClient, self).patch_twin_reported_properties( - reported_properties_patch - ) - - patch_twin_reported_properties.__doc__ = ( - IoTHubDeviceClient.patch_twin_reported_properties.__doc__ - ) - setattr(IoTHubDeviceClient, "patch_twin_reported_properties", patch_twin_reported_properties) - - def receive_method_request(self, method_name=None): - return super(IoTHubDeviceClient, self).receive_method_request(method_name) - - receive_method_request.__doc__ = IoTHubDeviceClient.receive_method_request.__doc__ - setattr(IoTHubDeviceClient, "receive_method_request", receive_method_request) - - def receive_twin_desired_properties_patch(self): - return super(IoTHubDeviceClient, self).receive_twin_desired_properties_patch() - - receive_twin_desired_properties_patch.__doc__ = ( - IoTHubDeviceClient.receive_twin_desired_properties_patch.__doc__ - ) - setattr( - IoTHubDeviceClient, - "receive_twin_desired_properties_patch", - receive_twin_desired_properties_patch, - ) - - async def send_message(self, message): - return await super(IoTHubDeviceClient, self).send_message(message) - 
- send_message.__doc__ = IoTHubDeviceClient.send_message.__doc__ - setattr(IoTHubDeviceClient, "send_message", send_message) - - async def send_method_response(self, method_response): - return await super(IoTHubDeviceClient, self).send_method_response(method_response) - - send_method_response.__doc__ = IoTHubDeviceClient.send_method_response.__doc__ - setattr(IoTHubDeviceClient, "send_method_response", send_method_response) - - async def shutdown(self): - return await super(IoTHubDeviceClient, self).shutdown() - - shutdown.__doc__ = IoTHubDeviceClient.shutdown.__doc__ - setattr(IoTHubDeviceClient, "shutdown", shutdown) - - async def update_sastoken(self, sastoken): - return await super(IoTHubDeviceClient, self).update_sastoken(sastoken) - - update_sastoken.__doc__ = IoTHubDeviceClient.update_sastoken.__doc__ - setattr(IoTHubDeviceClient, "update_sastoken", update_sastoken) - - def create_from_connection_string(cls, connection_string, **kwargs): - return super(IoTHubDeviceClient, cls).create_from_connection_string( - connection_string, **kwargs - ) - - create_from_connection_string.__doc__ = IoTHubDeviceClient.create_from_connection_string.__doc__ - setattr( - IoTHubDeviceClient, - "create_from_connection_string", - classmethod(create_from_connection_string), - ) - - def create_from_sastoken(cls, sastoken, **kwargs): - return super(IoTHubDeviceClient, cls).create_from_sastoken(sastoken, **kwargs) - - create_from_sastoken.__doc__ = IoTHubDeviceClient.create_from_sastoken.__doc__ - setattr(IoTHubDeviceClient, "create_from_sastoken", classmethod(create_from_sastoken)) - - def create_from_symmetric_key(cls, symmetric_key, hostname, device_id, **kwargs): - return super(IoTHubDeviceClient, cls).create_from_symmetric_key( - symmetric_key, hostname, device_id, **kwargs - ) - - create_from_symmetric_key.__doc__ = IoTHubDeviceClient.create_from_symmetric_key.__doc__ - setattr(IoTHubDeviceClient, "create_from_symmetric_key", classmethod(create_from_symmetric_key)) - - def 
create_from_x509_certificate(cls, x509, hostname, device_id, **kwargs): - return super(IoTHubDeviceClient, cls).create_from_x509_certificate( - x509, hostname, device_id, **kwargs - ) - - create_from_x509_certificate.__doc__ = IoTHubDeviceClient.create_from_x509_certificate.__doc__ - setattr( - IoTHubDeviceClient, - "create_from_x509_certificate", - classmethod(create_from_x509_certificate), - ) - setattr(IoTHubDeviceClient, "connected", IoTHubDeviceClient.connected) - setattr( - IoTHubDeviceClient, "on_background_exception", IoTHubDeviceClient.on_background_exception - ) - setattr( - IoTHubDeviceClient, - "on_connection_state_change", - IoTHubDeviceClient.on_connection_state_change, - ) - setattr(IoTHubDeviceClient, "on_message_received", IoTHubDeviceClient.on_message_received) - setattr( - IoTHubDeviceClient, - "on_method_request_received", - IoTHubDeviceClient.on_method_request_received, - ) - setattr( - IoTHubDeviceClient, "on_new_sastoken_required", IoTHubDeviceClient.on_new_sastoken_required - ) - setattr( - IoTHubDeviceClient, - "on_twin_desired_properties_patch_received", - IoTHubDeviceClient.on_twin_desired_properties_patch_received, - ) - from azure.iot.device.iothub.aio.async_clients import IoTHubModuleClient as IoTHubModuleClient - - async def connect(self): - return await super(IoTHubModuleClient, self).connect() - - connect.__doc__ = IoTHubModuleClient.connect.__doc__ - setattr(IoTHubModuleClient, "connect", connect) - - async def disconnect(self): - return await super(IoTHubModuleClient, self).disconnect() - - disconnect.__doc__ = IoTHubModuleClient.disconnect.__doc__ - setattr(IoTHubModuleClient, "disconnect", disconnect) - - async def get_twin(self): - return await super(IoTHubModuleClient, self).get_twin() - - get_twin.__doc__ = IoTHubModuleClient.get_twin.__doc__ - setattr(IoTHubModuleClient, "get_twin", get_twin) - - async def patch_twin_reported_properties(self, reported_properties_patch): - return await super(IoTHubModuleClient, 
self).patch_twin_reported_properties( - reported_properties_patch - ) - - patch_twin_reported_properties.__doc__ = ( - IoTHubModuleClient.patch_twin_reported_properties.__doc__ - ) - setattr(IoTHubModuleClient, "patch_twin_reported_properties", patch_twin_reported_properties) - - def receive_method_request(self, method_name=None): - return super(IoTHubModuleClient, self).receive_method_request(method_name) - - receive_method_request.__doc__ = IoTHubModuleClient.receive_method_request.__doc__ - setattr(IoTHubModuleClient, "receive_method_request", receive_method_request) - - def receive_twin_desired_properties_patch(self): - return super(IoTHubModuleClient, self).receive_twin_desired_properties_patch() - - receive_twin_desired_properties_patch.__doc__ = ( - IoTHubModuleClient.receive_twin_desired_properties_patch.__doc__ - ) - setattr( - IoTHubModuleClient, - "receive_twin_desired_properties_patch", - receive_twin_desired_properties_patch, - ) - - async def send_message(self, message): - return await super(IoTHubModuleClient, self).send_message(message) - - send_message.__doc__ = IoTHubModuleClient.send_message.__doc__ - setattr(IoTHubModuleClient, "send_message", send_message) - - async def send_method_response(self, method_response): - return await super(IoTHubModuleClient, self).send_method_response(method_response) - - send_method_response.__doc__ = IoTHubModuleClient.send_method_response.__doc__ - setattr(IoTHubModuleClient, "send_method_response", send_method_response) - - async def shutdown(self): - return await super(IoTHubModuleClient, self).shutdown() - - shutdown.__doc__ = IoTHubModuleClient.shutdown.__doc__ - setattr(IoTHubModuleClient, "shutdown", shutdown) - - async def update_sastoken(self, sastoken): - return await super(IoTHubModuleClient, self).update_sastoken(sastoken) - - update_sastoken.__doc__ = IoTHubModuleClient.update_sastoken.__doc__ - setattr(IoTHubModuleClient, "update_sastoken", update_sastoken) - - def create_from_connection_string(cls, 
connection_string, **kwargs): - return super(IoTHubModuleClient, cls).create_from_connection_string( - connection_string, **kwargs - ) - - create_from_connection_string.__doc__ = IoTHubModuleClient.create_from_connection_string.__doc__ - setattr( - IoTHubModuleClient, - "create_from_connection_string", - classmethod(create_from_connection_string), - ) - - def create_from_edge_environment(cls, **kwargs): - return super(IoTHubModuleClient, cls).create_from_edge_environment(**kwargs) - - create_from_edge_environment.__doc__ = IoTHubModuleClient.create_from_edge_environment.__doc__ - setattr( - IoTHubModuleClient, - "create_from_edge_environment", - classmethod(create_from_edge_environment), - ) - - def create_from_sastoken(cls, sastoken, **kwargs): - return super(IoTHubModuleClient, cls).create_from_sastoken(sastoken, **kwargs) - - create_from_sastoken.__doc__ = IoTHubModuleClient.create_from_sastoken.__doc__ - setattr(IoTHubModuleClient, "create_from_sastoken", classmethod(create_from_sastoken)) - - def create_from_x509_certificate(cls, x509, hostname, device_id, module_id, **kwargs): - return super(IoTHubModuleClient, cls).create_from_x509_certificate( - x509, hostname, device_id, module_id, **kwargs - ) - - create_from_x509_certificate.__doc__ = IoTHubModuleClient.create_from_x509_certificate.__doc__ - setattr( - IoTHubModuleClient, - "create_from_x509_certificate", - classmethod(create_from_x509_certificate), - ) - setattr(IoTHubModuleClient, "connected", IoTHubModuleClient.connected) - setattr( - IoTHubModuleClient, "on_background_exception", IoTHubModuleClient.on_background_exception - ) - setattr( - IoTHubModuleClient, - "on_connection_state_change", - IoTHubModuleClient.on_connection_state_change, - ) - setattr(IoTHubModuleClient, "on_message_received", IoTHubModuleClient.on_message_received) - setattr( - IoTHubModuleClient, - "on_method_request_received", - IoTHubModuleClient.on_method_request_received, - ) - setattr( - IoTHubModuleClient, 
"on_new_sastoken_required", IoTHubModuleClient.on_new_sastoken_required - ) - setattr( - IoTHubModuleClient, - "on_twin_desired_properties_patch_received", - IoTHubModuleClient.on_twin_desired_properties_patch_received, - ) - from azure.iot.device.provisioning.aio.async_provisioning_device_client import ( - ProvisioningDeviceClient as ProvisioningDeviceClient, - ) - - def create_from_symmetric_key( - cls, provisioning_host, registration_id, id_scope, symmetric_key, **kwargs - ): - return super(ProvisioningDeviceClient, cls).create_from_symmetric_key( - provisioning_host, registration_id, id_scope, symmetric_key, **kwargs - ) - - create_from_symmetric_key.__doc__ = ProvisioningDeviceClient.create_from_symmetric_key.__doc__ - setattr( - ProvisioningDeviceClient, - "create_from_symmetric_key", - classmethod(create_from_symmetric_key), - ) - - def create_from_x509_certificate( - cls, provisioning_host, registration_id, id_scope, x509, **kwargs - ): - return super(ProvisioningDeviceClient, cls).create_from_x509_certificate( - provisioning_host, registration_id, id_scope, x509, **kwargs - ) - - create_from_x509_certificate.__doc__ = ( - ProvisioningDeviceClient.create_from_x509_certificate.__doc__ - ) - setattr( - ProvisioningDeviceClient, - "create_from_x509_certificate", - classmethod(create_from_x509_certificate), - ) - setattr( - ProvisioningDeviceClient, - "provisioning_payload", - ProvisioningDeviceClient.provisioning_payload, - ) diff --git a/azure-iot-device/azure/iot/device/common/__init__.py b/azure-iot-device/azure/iot/device/common/__init__.py deleted file mode 100644 index 89a2ca3bb..000000000 --- a/azure-iot-device/azure/iot/device/common/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Azure IoT Device Common - -This package provides shared modules for use with various Azure IoT device-side clients. 
- -INTERNAL USAGE ONLY -""" - -from .models import X509, ProxyOptions - -__all__ = ["X509", "ProxyOptions"] diff --git a/azure-iot-device/azure/iot/device/common/alarm.py b/azure-iot-device/azure/iot/device/common/alarm.py deleted file mode 100644 index 154cd219c..000000000 --- a/azure-iot-device/azure/iot/device/common/alarm.py +++ /dev/null @@ -1,42 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from threading import Thread, Event -import time - -# NOTE: The Alarm class is very similar, but fundamentally different from threading.Timer. -# Beyond just the input format difference (a specific time for Alarm vs. an interval for Timer), -# the manner in which they keep time is different. A Timer will only tick towards its interval -# while the system is running, but an Alarm will go off at a specific time, so system sleep will -# not throw off the timekeeping. In the case that the Alarm time occurs while the system is asleep -# the Alarm will trigger upon system wake. - - -class Alarm(Thread): - """Call a function at a specified time""" - - def __init__(self, alarm_time, function, args=None, kwargs=None): - Thread.__init__(self) - self.alarm_time = alarm_time - self.function = function - self.args = args if args is not None else [] - self.kwargs = kwargs if kwargs is not None else {} - self.finished = Event() - - def cancel(self): - """Stop the alarm if it hasn't finished yet.""" - self.finished.set() - - def run(self): - """Method representing the thread's activity. - Overrides the method inherited from Thread. 
- Will invoke the Alarm's given function at the given alarm time (accurate within 1 second) - """ - while not self.finished.is_set() and time.time() < self.alarm_time: - self.finished.wait(1) - - if not self.finished.is_set(): - self.function(*self.args, **self.kwargs) - self.finished.set() diff --git a/azure-iot-device/azure/iot/device/common/async_adapter.py b/azure-iot-device/azure/iot/device/common/async_adapter.py deleted file mode 100644 index ca232822a..000000000 --- a/azure-iot-device/azure/iot/device/common/async_adapter.py +++ /dev/null @@ -1,91 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module contains tools for adapting sync code for use in async coroutines.""" - -import functools -import logging -import traceback -import azure.iot.device.common.asyncio_compat as asyncio_compat - -logger = logging.getLogger(__name__) - - -def emulate_async(fn): - """Returns a coroutine function that calls a given function with emulated asynchronous - behavior via use of multithreading. - - Can be applied as a decorator. - - :param fn: The sync function to be run in async. - :returns: A coroutine function that will call the given sync function. 
- """ - - @functools.wraps(fn) - async def async_fn_wrapper(*args, **kwargs): - loop = asyncio_compat.get_running_loop() - - # Run fn in default ThreadPoolExecutor (CPU * 5 threads) - return await loop.run_in_executor(None, functools.partial(fn, *args, **kwargs)) - - return async_fn_wrapper - - -class AwaitableCallback(object): - """A sync callback whose completion can be waited upon.""" - - def __init__(self, return_arg_name=None): - """Creates an instance of an AwaitableCallback""" - - # LBYL because this mistake doesn't cause an exception until the callback - # which is much later and very difficult to trace back to here. - if return_arg_name and not isinstance(return_arg_name, str): - raise TypeError("internal error: return_arg_name must be a string") - - loop = asyncio_compat.get_running_loop() - self.future = loop.create_future() - - def wrapping_callback(*args, **kwargs): - # Use event loop from outer scope, since the threads it will be used in will not have - # an event loop. future.set_result() and future.set_exception have to be called in an - # event loop or they do not work. - if "error" in kwargs and kwargs["error"]: - exception = kwargs["error"] - elif return_arg_name: - if return_arg_name in kwargs: - exception = None - result = kwargs[return_arg_name] - else: - raise TypeError( - "internal error: expected argument with name '{}', did not get".format( - return_arg_name - ) - ) - else: - exception = None - result = None - - if exception: - # Do not use exc_info parameter on logger.* calls. 
This causes pytest to save the traceback which saves stack frames which shows up as a leak - logger.info("Callback completed with error {}".format(exception)) - logger.info(traceback.format_exception_only(type(exception), exception)) - loop.call_soon_threadsafe(self.future.set_exception, exception) - else: - logger.debug("Callback completed with result {}".format(result)) - loop.call_soon_threadsafe(self.future.set_result, result) - - self.callback = wrapping_callback - - def __call__(self, *args, **kwargs): - """Calls the callback. Returns the result.""" - return self.callback(*args, **kwargs) - - async def completion(self): - """Awaitable coroutine method that will return once the AwaitableCallback - has been completed. - - :returns: Result of the callback when it was called. - """ - return await self.future diff --git a/azure-iot-device/azure/iot/device/common/asyncio_compat.py b/azure-iot-device/azure/iot/device/common/asyncio_compat.py deleted file mode 100644 index 0ab3f4b21..000000000 --- a/azure-iot-device/azure/iot/device/common/asyncio_compat.py +++ /dev/null @@ -1,61 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module contains compatibility tools for bridging different versions of asyncio""" - -import asyncio -import sys - - -def get_running_loop(): - """Gets the currently running event loop - - Uses asyncio.get_running_loop() if available (Python 3.7+) or a back-ported - version of the same function in 3.6. - """ - try: - loop = asyncio.get_running_loop() - except AttributeError: - loop = asyncio._get_running_loop() - if loop is None: - raise RuntimeError("no running event loop") - return loop - - -def create_task(coro): - """Creates a Task object. 
- - If available (Python 3.7+), use asyncio.create_task, which is preferred as it is - more specific for the goal of immediately scheduling a task from a coroutine. If - not available, use the more general purpose asyncio.ensure_future. - - :returns: A new Task object. - """ - try: - task = asyncio.create_task(coro) - except AttributeError: - task = asyncio.ensure_future(coro) - return task - - -def run(coro): - """Execute the coroutine coro and return the result. - - It creates a new event loop and closes it at the end. - Cannot be called when another asyncio event loop is running in the same thread. - - If available (Python 3.7+) use asyncio.run. If not available, use a custom implementation - that achieves the same thing - """ - if sys.version_info >= (3, 7): - return asyncio.run(coro) - else: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - return loop.run_until_complete(coro) - finally: - loop.close() - asyncio.set_event_loop(None) diff --git a/azure-iot-device/azure/iot/device/common/auth/__init__.py b/azure-iot-device/azure/iot/device/common/auth/__init__.py deleted file mode 100644 index f58862451..000000000 --- a/azure-iot-device/azure/iot/device/common/auth/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .signing_mechanism import SymmetricKeySigningMechanism # noqa: F401 - -# NOTE: Please import the connection_string and sastoken modules directly -# rather than through the package interface, as the modules contain many -# related items for their respective domains, which we do not wish to expose -# at length here. diff --git a/azure-iot-device/azure/iot/device/common/auth/sastoken.py b/azure-iot-device/azure/iot/device/common/auth/sastoken.py deleted file mode 100644 index 171636431..000000000 --- a/azure-iot-device/azure/iot/device/common/auth/sastoken.py +++ /dev/null @@ -1,154 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. 
-# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module contains tools for working with Shared Access Signature (SAS) Tokens""" - -import time -import urllib - - -class SasTokenError(Exception): - """Error in SasToken""" - - pass - - -class RenewableSasToken(object): - """Renewable Shared Access Signature Token used to authenticate a request. - - This token is 'renewable', which means that it can be updated when necessary to - prevent expiry, by using the .refresh() method. - - Data Attributes: - expiry_time (int): Time that token will expire (in UTC, since epoch) - ttl (int): Time to live for the token, in seconds - """ - - _auth_rule_token_format = ( - "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}&skn={keyname}" - ) - _simple_token_format = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}" - - def __init__(self, uri, signing_mechanism, key_name=None, ttl=3600): - """ - :param str uri: URI of the resource to be accessed - :param signing_mechanism: The signing mechanism to use in the SasToken - :type signing_mechanism: Child classes of :class:`azure.iot.common.SigningMechanism` - :param str key_name: Symmetric Key Name (optional) - :param int ttl: Time to live for the token, in seconds (default 3600) - - :raises: SasTokenError if an error occurs building a SasToken - """ - self._uri = uri - self._signing_mechanism = signing_mechanism - self._key_name = key_name - self._expiry_time = None # This will be overwritten by the .refresh() call below - self._token = None # This will be overwritten by the .refresh() call below - - self.ttl = ttl - self.refresh() - - def __str__(self): - return self._token - - def refresh(self): - """ - Refresh the SasToken lifespan, giving it a new expiry time, and generating a new token. 
- """ - self._expiry_time = int(time.time() + self.ttl) - self._token = self._build_token() - - def _build_token(self): - """Build SasToken representation - - :returns: String representation of the token - """ - url_encoded_uri = urllib.parse.quote(self._uri, safe="") - message = url_encoded_uri + "\n" + str(self.expiry_time) - try: - signature = self._signing_mechanism.sign(message) - except Exception as e: - # Because of variant signing mechanisms, we don't know what error might be raised. - # So we catch all of them. - raise SasTokenError("Unable to build SasToken from given values") from e - url_encoded_signature = urllib.parse.quote(signature, safe="") - if self._key_name: - token = self._auth_rule_token_format.format( - resource=url_encoded_uri, - signature=url_encoded_signature, - expiry=str(self.expiry_time), - keyname=self._key_name, - ) - else: - token = self._simple_token_format.format( - resource=url_encoded_uri, - signature=url_encoded_signature, - expiry=str(self.expiry_time), - ) - return token - - @property - def expiry_time(self): - """Expiry Time is READ ONLY""" - return self._expiry_time - - -class NonRenewableSasToken(object): - """NonRenewable Shared Access Signature Token used to authenticate a request. - - This token is 'non-renewable', which means that it is invalid once it expires, and there - is no way to keep it alive. Instead, a new token must be created. 
- - Data Attributes: - expiry_time (int): Time that token will expire (in UTC, since epoch) - resource_uri (str): URI for the resource the Token provides authentication to access - """ - - def __init__(self, sastoken_string): - """ - :param str sastoken_string: A string representation of a SAS token - """ - self._token = sastoken_string - self._token_info = get_sastoken_info_from_string(self._token) - - def __str__(self): - return self._token - - @property - def expiry_time(self): - """Expiry Time is READ ONLY""" - return int(self._token_info["se"]) - - @property - def resource_uri(self): - """Resource URI is READ ONLY""" - uri = self._token_info["sr"] - return urllib.parse.unquote(uri) - - -REQUIRED_SASTOKEN_FIELDS = ["sr", "sig", "se"] -VALID_SASTOKEN_FIELDS = REQUIRED_SASTOKEN_FIELDS + ["skn"] - - -def get_sastoken_info_from_string(sastoken_string): - pieces = sastoken_string.split("SharedAccessSignature ") - if len(pieces) != 2: - raise SasTokenError("Invalid SasToken string: Not a SasToken ") - - # Get sastoken info as dictionary - try: - sastoken_info = dict(map(str.strip, sub.split("=", 1)) for sub in pieces[1].split("&")) - except Exception as e: - raise SasTokenError("Invalid SasToken string: Incorrectly formatted") from e - - # Validate that all required fields are present - if not all(key in sastoken_info for key in REQUIRED_SASTOKEN_FIELDS): - raise SasTokenError("Invalid SasToken string: Not all required fields present") - - # Validate that no unexpected fields are present - if not all(key in VALID_SASTOKEN_FIELDS for key in sastoken_info): - raise SasTokenError("Invalid SasToken string: Unexpected fields present") - - return sastoken_info diff --git a/azure-iot-device/azure/iot/device/common/evented_callback.py b/azure-iot-device/azure/iot/device/common/evented_callback.py deleted file mode 100644 index 0e8e997da..000000000 --- a/azure-iot-device/azure/iot/device/common/evented_callback.py +++ /dev/null @@ -1,71 +0,0 @@ -# 
------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import threading -import logging -import traceback - -logger = logging.getLogger(__name__) - - -class EventedCallback(object): - """ - A sync callback whose completion can be waited upon. - """ - - def __init__(self, return_arg_name=None): - """ - Creates an instance of an EventedCallback. - - """ - # LBYL because this mistake doesn't cause an exception until the callback - # which is much later and very difficult to trace back to here. - if return_arg_name and not isinstance(return_arg_name, str): - raise TypeError("internal error: return_arg_name must be a string") - - self.completion_event = threading.Event() - self.exception = None - self.result = None - - def wrapping_callback(*args, **kwargs): - if "error" in kwargs and kwargs["error"]: - self.exception = kwargs["error"] - elif return_arg_name: - if return_arg_name in kwargs: - self.result = kwargs[return_arg_name] - else: - raise TypeError( - "internal error: expected argument with name '{}', did not get".format( - return_arg_name - ) - ) - - if self.exception: - # Do not use exc_info parameter on logger.* calls. This causes pytest to save the traceback which saves stack frames which shows up as a leak - logger.info("Callback completed with error {}".format(self.exception)) - logger.info(traceback.format_exc()) - else: - logger.debug("Callback completed with result {}".format(self.result)) - - self.completion_event.set() - - self.callback = wrapping_callback - - def __call__(self, *args, **kwargs): - """ - Calls the callback. - """ - self.callback(*args, **kwargs) - - def wait_for_completion(self, *args, **kwargs): - """ - Wait for the callback to be called, and return the results. 
- """ - self.completion_event.wait(*args, **kwargs) - - if self.exception: - raise self.exception - else: - return self.result diff --git a/azure-iot-device/azure/iot/device/common/handle_exceptions.py b/azure-iot-device/azure/iot/device/common/handle_exceptions.py deleted file mode 100644 index fdb5e89bc..000000000 --- a/azure-iot-device/azure/iot/device/common/handle_exceptions.py +++ /dev/null @@ -1,56 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import logging -import traceback - -logger = logging.getLogger(__name__) - - -def handle_background_exception(e): - """ - Function which handled exceptions that are caught in background thread. This is - typically called from the callback thread inside the pipeline. These exceptions - need special handling because callback functions are typically called inside a - non-application thread in response to non-user-initiated actions, so there's - nobody else to catch them. - - This function gets called from inside an arbitrary thread context, so code that - runs from this function should be limited to the bare minimum. - - :param Error e: Exception object raised from inside a background thread - """ - - # @FUTURE: We should add a mechanism which allows applications to receive these - # exceptions so they can respond accordingly - logger.warning(msg="Exception caught in background thread. Unable to handle.") - logger.warning(traceback.format_exception_only(type(e), e)) - - -def swallow_unraised_exception(e, log_msg=None, log_lvl="warning"): - """Swallow and log an exception object. - - Convenience function for logging, as exceptions can only be logged correctly from within a - except block. - - :param Exception e: Exception object to be swallowed. 
- :param str log_msg: Optional message to use when logging. - :param str log_lvl: The log level to use for logging. Default "warning". - """ - try: - raise e - except Exception: - if log_lvl == "warning": - logger.warning(log_msg) - logger.warning(traceback.format_exc()) - elif log_lvl == "error": - logger.error(log_msg) - logger.error(traceback.format_exc()) - elif log_lvl == "info": - logger.info(log_msg) - logger.info(traceback.format_exc()) - else: - logger.debug(log_msg) - logger.debug(traceback.format_exc()) diff --git a/azure-iot-device/azure/iot/device/common/http_transport.py b/azure-iot-device/azure/iot/device/common/http_transport.py deleted file mode 100644 index ab1cbfbdc..000000000 --- a/azure-iot-device/azure/iot/device/common/http_transport.py +++ /dev/null @@ -1,210 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -import ssl -import requests -from . import transport_exceptions as exceptions -from .pipeline import pipeline_thread - -logger = logging.getLogger(__name__) - - -# NOTE: There should probably be a more global timeout configuration, but for now this will do. -HTTP_TIMEOUT = 10 - - -class HTTPTransport(object): - """ - A wrapper class that provides an implementation-agnostic HTTP interface. - """ - - def __init__( - self, - hostname, - server_verification_cert=None, - x509_cert=None, - cipher=None, - proxy_options=None, - ): - """ - Constructor to instantiate an HTTP protocol wrapper. - - :param str hostname: Hostname or IP address of the remote host. - :param str server_verification_cert: Certificate which can be used to validate a server-side TLS connection (optional). 
- :param str cipher: Cipher string in OpenSSL cipher list format (optional) - :param x509_cert: Certificate which can be used to authenticate connection to a server in lieu of a password (optional). - :param proxy_options: Options for sending traffic through proxy servers. - """ - self._hostname = hostname - self._server_verification_cert = server_verification_cert - self._x509_cert = x509_cert - self._cipher = cipher - self._proxies = format_proxies(proxy_options) - self._http_adapter = self._create_http_adapter() - - def _create_http_adapter(self): - """ - This method creates a custom HTTPAdapter for use with a requests library session. - It will allow for use of a custom configured SSL context. - """ - ssl_context = self._create_ssl_context() - - class CustomSSLContextHTTPAdapter(requests.adapters.HTTPAdapter): - def init_poolmanager(self, *args, **kwargs): - kwargs["ssl_context"] = ssl_context - return super().init_poolmanager(*args, **kwargs) - - def proxy_manager_for(self, *args, **kwargs): - kwargs["ssl_context"] = ssl_context - return super().proxy_manager_for(*args, **kwargs) - - return CustomSSLContextHTTPAdapter() - - def _create_ssl_context(self): - """ - This method creates the SSLContext object used to authenticate the connection. The generated context is used by the http_client and is necessary when authenticating using a self-signed X509 cert or trusted X509 cert - """ - logger.debug("creating a SSL context") - ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLSv1_2) - - if self._server_verification_cert: - ssl_context.load_verify_locations(cadata=self._server_verification_cert) - else: - ssl_context.load_default_certs() - - if self._cipher: - try: - ssl_context.set_ciphers(self._cipher) - except ssl.SSLError as e: - # TODO: custom error with more detail? 
- raise e - - if self._x509_cert is not None: - logger.debug("configuring SSL context with client-side certificate and key") - ssl_context.load_cert_chain( - self._x509_cert.certificate_file, - self._x509_cert.key_file, - self._x509_cert.pass_phrase, - ) - - ssl_context.verify_mode = ssl.CERT_REQUIRED - ssl_context.check_hostname = True - - return ssl_context - - @pipeline_thread.invoke_on_http_thread_nowait - def request(self, method, path, callback, body="", headers={}, query_params=""): - """ - This method creates a connection to a remote host, sends a request to that host, and then waits for and reads the response from that request. - - :param str method: The request method (e.g. "POST") - :param str path: The path for the URL - :param Function callback: The function that gets called when this operation is complete or has failed. The callback function must accept an error and a response dictionary, where the response dictionary contains a status code, a reason, and a response string. - :param str body: The body of the HTTP request to be sent following the headers. - :param dict headers: A dictionary that provides extra HTTP headers to be sent with the request. - :param str query_params: The optional query parameters to be appended at the end of the URL. - """ - # Sends a complete request to the server - logger.info("sending https {} request to {} .".format(method, path)) - - # Mount the transport adapter to a requests session - session = requests.Session() - session.mount("https://", self._http_adapter) - - # Format request URL - # TODO: URL formation should be moved to pipeline_stages_iothub_http, I believe, as - # depending on the operation this could have a different hostname, due to different - # destinations. 
For now this isn't a problem yet, because no possible client can - # support more than one HTTP operation - # (Device can do File Upload but NOT Method Invoke, Module can do Method Invoke and NOT file upload) - url = "https://{hostname}/{path}{query_params}".format( - hostname=self._hostname, - path=path, - query_params="?" + query_params if query_params else "", - ) - - try: - # Note that various configuration options are not set here due to them being set - # via the HTTPAdapter that was mounted at session level. - if method == "GET": - response = session.get( - url, data=body, headers=headers, proxies=self._proxies, timeout=HTTP_TIMEOUT - ) - elif method == "POST": - response = session.post( - url, data=body, headers=headers, proxies=self._proxies, timeout=HTTP_TIMEOUT - ) - elif method == "PUT": - response = session.put( - url, data=body, headers=headers, proxies=self._proxies, timeout=HTTP_TIMEOUT - ) - elif method == "PATCH": - response = session.patch( - url, data=body, headers=headers, proxies=self._proxies, timeout=HTTP_TIMEOUT - ) - elif method == "DELETE": - response = session.delete( - url, data=body, headers=headers, proxies=self._proxies, timeout=HTTP_TIMEOUT - ) - else: - raise ValueError("Invalid method type: {}".format(method)) - except ValueError as e: - # Allow ValueError to propagate - callback(error=e) - except requests.exceptions.Timeout as e: - # Allow Timeout to propagate - # NOTE: This breaks the convention in transports where we don't expose anything - # but builtin exceptions and the exceptions defined in transport_exceptions.py. - # However, we don't exactly have infrastructure to support timeout at Transport level. 
- # For now, just expose it, and if/when we more broadly support timeout, this can change - callback(error=e) - except Exception as e: - # Raise error via the callback - new_err = exceptions.ProtocolClientError("Unexpected HTTPS failure during connect") - new_err.__cause__ = e - callback(error=new_err) - else: - # Return the data from the response via the callback - response_obj = { - "status_code": response.status_code, - "reason": response.reason, - "resp": response.text, - } - callback(response=response_obj) - - -def format_proxies(proxy_options): - """ - Format the data from the proxy_options object into a format for use with the requests library - """ - proxies = {} - if proxy_options: - # Basic address/port formatting - proxy = "{address}:{port}".format( - address=proxy_options.proxy_address, port=proxy_options.proxy_port - ) - # Add credentials if necessary - if proxy_options.proxy_username and proxy_options.proxy_password: - auth = "{username}:{password}".format( - username=proxy_options.proxy_username, password=proxy_options.proxy_password - ) - proxy = auth + "@" + proxy - # Set proxy for use on HTTP or HTTPS connections - if proxy_options.proxy_type == "HTTP": - proxies["http"] = "http://" + proxy - proxies["https"] = "http://" + proxy - elif proxy_options.proxy_type == "SOCKS4": - proxies["http"] = "socks4://" + proxy - proxies["https"] = "socks4://" + proxy - elif proxy_options.proxy_type == "SOCKS5": - proxies["http"] = "socks5://" + proxy - proxies["https"] = "socks5://" + proxy - else: - # This should be unreachable due to validation on the ProxyOptions object - raise ValueError("Invalid proxy type: {}".format(proxy_options.proxy_type)) - - return proxies diff --git a/azure-iot-device/azure/iot/device/common/models/__init__.py b/azure-iot-device/azure/iot/device/common/models/__init__.py deleted file mode 100644 index b5d479ae9..000000000 --- a/azure-iot-device/azure/iot/device/common/models/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Azure Device 
Models - -This package provides object models for use within the Azure Provisioning Device SDK and Azure IoTHub Device SDK. -""" - -from .x509 import X509 # noqa: F401 -from .proxy_options import ProxyOptions # noqa: F401 diff --git a/azure-iot-device/azure/iot/device/common/models/proxy_options.py b/azure-iot-device/azure/iot/device/common/models/proxy_options.py deleted file mode 100644 index b94ded075..000000000 --- a/azure-iot-device/azure/iot/device/common/models/proxy_options.py +++ /dev/null @@ -1,74 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -This module represents proxy options to enable sending traffic through proxy servers. -""" -import socks - -string_to_socks_constant_map = {"HTTP": socks.HTTP, "SOCKS4": socks.SOCKS4, "SOCKS5": socks.SOCKS5} - -socks_constant_to_string_map = {socks.HTTP: "HTTP", socks.SOCKS4: "SOCKS4", socks.SOCKS5: "SOCKS5"} - - -class ProxyOptions(object): - """ - A class containing various options to send traffic through proxy servers by enabling - proxying of MQTT connection. - """ - - def __init__( - self, proxy_type, proxy_addr, proxy_port, proxy_username=None, proxy_password=None - ): - """ - Initializer for proxy options. - :param str proxy_type: The type of the proxy server. This can be one of three possible choices: "HTTP", "SOCKS4", or "SOCKS5" - :param str proxy_addr: IP address or DNS name of proxy server - :param int proxy_port: The port of the proxy server. Defaults to 1080 for socks and 8080 for http. - :param str proxy_username: (optional) username for SOCKS5 proxy, or userid for SOCKS4 proxy.This parameter is ignored if an HTTP server is being used. 
- If it is not provided, authentication will not be used (servers may accept unauthenticated requests). - :param str proxy_password: (optional) This parameter is valid only for SOCKS5 servers and specifies the respective password for the username provided. - """ - (self._proxy_type, self._proxy_type_socks) = format_proxy_type(proxy_type) - self._proxy_addr = proxy_addr - self._proxy_port = int(proxy_port) - self._proxy_username = proxy_username - self._proxy_password = proxy_password - - @property - def proxy_type(self): - return self._proxy_type - - @property - def proxy_type_socks(self): - return self._proxy_type_socks - - @property - def proxy_address(self): - return self._proxy_addr - - @property - def proxy_port(self): - return self._proxy_port - - @property - def proxy_username(self): - return self._proxy_username - - @property - def proxy_password(self): - return self._proxy_password - - -def format_proxy_type(proxy_type): - """Returns a tuple of formats for proxy type (string, socks library constant)""" - try: - return (proxy_type, string_to_socks_constant_map[proxy_type]) - except KeyError: - # Backwards compatibility for when we used the socks library constants in the API - try: - return (socks_constant_to_string_map[proxy_type], proxy_type) - except KeyError: - raise ValueError("Invalid Proxy Type") diff --git a/azure-iot-device/azure/iot/device/common/models/x509.py b/azure-iot-device/azure/iot/device/common/models/x509.py deleted file mode 100644 index 8f5fdb82b..000000000 --- a/azure-iot-device/azure/iot/device/common/models/x509.py +++ /dev/null @@ -1,39 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -"""This module represents a certificate that is responsible for providing client provided x509 certificates -that will eventually establish the authenticity of devices to IoTHub and Provisioning Services. -""" - - -class X509(object): - """ - A class with references to the certificate, key, and optional pass-phrase used to authenticate - a TLS connection using x509 certificates - """ - - def __init__(self, cert_file, key_file, pass_phrase=None): - """ - Initializer for X509 Certificate - :param cert_file: The file path to contents of the certificate (or certificate chain) - used to authenticate the device. - :param key_file: The file path to the key associated with the certificate - :param pass_phrase: (optional) The pass_phrase used to encode the key file - """ - self._cert_file = cert_file - self._key_file = key_file - self._pass_phrase = pass_phrase - - @property - def certificate_file(self): - return self._cert_file - - @property - def key_file(self): - return self._key_file - - @property - def pass_phrase(self): - return self._pass_phrase diff --git a/azure-iot-device/azure/iot/device/common/mqtt_transport.py b/azure-iot-device/azure/iot/device/common/mqtt_transport.py deleted file mode 100644 index cf54a0416..000000000 --- a/azure-iot-device/azure/iot/device/common/mqtt_transport.py +++ /dev/null @@ -1,682 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import paho.mqtt.client as mqtt -import logging -import ssl -import threading -import traceback -import weakref -import socket -from . 
import transport_exceptions as exceptions -import socks - -logger = logging.getLogger(__name__) - -# Mapping of Paho CONNACK rc codes to Error object classes -# Used for connection callbacks -paho_connack_rc_to_error = { - mqtt.CONNACK_REFUSED_PROTOCOL_VERSION: exceptions.ProtocolClientError, - mqtt.CONNACK_REFUSED_IDENTIFIER_REJECTED: exceptions.ProtocolClientError, - mqtt.CONNACK_REFUSED_SERVER_UNAVAILABLE: exceptions.ConnectionFailedError, - mqtt.CONNACK_REFUSED_BAD_USERNAME_PASSWORD: exceptions.UnauthorizedError, - mqtt.CONNACK_REFUSED_NOT_AUTHORIZED: exceptions.UnauthorizedError, -} - -# Mapping of Paho rc codes to Error object classes -# Used for responses to Paho APIs and non-connection callbacks -paho_rc_to_error = { - mqtt.MQTT_ERR_NOMEM: exceptions.ProtocolClientError, - mqtt.MQTT_ERR_PROTOCOL: exceptions.ProtocolClientError, - mqtt.MQTT_ERR_INVAL: exceptions.ProtocolClientError, - mqtt.MQTT_ERR_NO_CONN: exceptions.NoConnectionError, - mqtt.MQTT_ERR_CONN_REFUSED: exceptions.ConnectionFailedError, - mqtt.MQTT_ERR_NOT_FOUND: exceptions.ConnectionFailedError, - mqtt.MQTT_ERR_CONN_LOST: exceptions.ConnectionDroppedError, - mqtt.MQTT_ERR_TLS: exceptions.UnauthorizedError, - mqtt.MQTT_ERR_PAYLOAD_SIZE: exceptions.ProtocolClientError, - mqtt.MQTT_ERR_NOT_SUPPORTED: exceptions.ProtocolClientError, - mqtt.MQTT_ERR_AUTH: exceptions.UnauthorizedError, - mqtt.MQTT_ERR_ACL_DENIED: exceptions.UnauthorizedError, - mqtt.MQTT_ERR_UNKNOWN: exceptions.ProtocolClientError, - mqtt.MQTT_ERR_ERRNO: exceptions.ProtocolClientError, - mqtt.MQTT_ERR_QUEUE_SIZE: exceptions.ProtocolClientError, - mqtt.MQTT_ERR_KEEPALIVE: exceptions.ConnectionDroppedError, -} - - -def _create_error_from_connack_rc_code(rc): - """ - Given a paho CONNACK rc code, return an Exception that can be raised - """ - message = mqtt.connack_string(rc) - if rc in paho_connack_rc_to_error: - return paho_connack_rc_to_error[rc](message) - else: - return exceptions.ProtocolClientError("Unknown CONNACK 
rc={}".format(rc)) - - -def _create_error_from_rc_code(rc): - """ - Given a paho rc code, return an Exception that can be raised - """ - if rc == 1: - # Paho returns rc=1 to mean "something went wrong. stop". We manually translate this to a ConnectionDroppedError. - return exceptions.ConnectionDroppedError("Paho returned rc==1") - elif rc in paho_rc_to_error: - message = mqtt.error_string(rc) - return paho_rc_to_error[rc](message) - else: - return exceptions.ProtocolClientError("Unknown rc=={}".format(rc)) - - -class MQTTTransport(object): - """ - A wrapper class that provides an implementation-agnostic MQTT message broker interface. - - :ivar on_mqtt_connected_handler: Event handler callback, called upon establishing a connection. - :type on_mqtt_connected_handler: Function - :ivar on_mqtt_disconnected_handler: Event handler callback, called upon a disconnection. - :type on_mqtt_disconnected_handler: Function - :ivar on_mqtt_message_received_handler: Event handler callback, called upon receiving a message. - :type on_mqtt_message_received_handler: Function - :ivar on_mqtt_connection_failure_handler: Event handler callback, called upon a connection failure. - :type on_mqtt_connection_failure_handler: Function - """ - - def __init__( - self, - client_id, - hostname, - username, - server_verification_cert=None, - x509_cert=None, - websockets=False, - cipher=None, - proxy_options=None, - keep_alive=None, - ): - """ - Constructor to instantiate an MQTT protocol wrapper. - :param str client_id: The id of the client connecting to the broker. - :param str hostname: Hostname or IP address of the remote broker. - :param str username: Username for login to the remote broker. - :param str server_verification_cert: Certificate which can be used to validate a server-side TLS connection (optional). - :param x509_cert: Certificate which can be used to authenticate connection to a server in lieu of a password (optional). 
- :param bool websockets: Indicates whether or not to enable a websockets connection in the Transport. - :param str cipher: Cipher string in OpenSSL cipher list format - :param proxy_options: Options for sending traffic through proxy servers. - """ - self._client_id = client_id - self._hostname = hostname - self._username = username - self._mqtt_client = None - self._server_verification_cert = server_verification_cert - self._x509_cert = x509_cert - self._websockets = websockets - self._cipher = cipher - self._proxy_options = proxy_options - self._keep_alive = keep_alive - - self.on_mqtt_connected_handler = None - self.on_mqtt_disconnected_handler = None - self.on_mqtt_message_received_handler = None - self.on_mqtt_connection_failure_handler = None - - self._op_manager = OperationManager() - - self._mqtt_client = self._create_mqtt_client() - - def _create_mqtt_client(self): - """ - Create the MQTT client object and assign all necessary event handler callbacks. - """ - logger.debug("creating mqtt client") - - # Instantiate the client - if self._websockets: - logger.info("Creating client for connecting using MQTT over websockets") - mqtt_client = mqtt.Client( - client_id=self._client_id, - clean_session=False, - protocol=mqtt.MQTTv311, - transport="websockets", - ) - mqtt_client.ws_set_options(path="/$iothub/websocket") - else: - logger.info("Creating client for connecting using MQTT over TCP") - mqtt_client = mqtt.Client( - client_id=self._client_id, clean_session=False, protocol=mqtt.MQTTv311 - ) - - if self._proxy_options: - logger.info("Setting custom proxy options on mqtt client") - mqtt_client.proxy_set( - proxy_type=self._proxy_options.proxy_type_socks, - proxy_addr=self._proxy_options.proxy_address, - proxy_port=self._proxy_options.proxy_port, - proxy_username=self._proxy_options.proxy_username, - proxy_password=self._proxy_options.proxy_password, - ) - - mqtt_client.enable_logger(logging.getLogger("paho")) - - # Configure TLS/SSL - ssl_context = 
self._create_ssl_context() - mqtt_client.tls_set_context(context=ssl_context) - - # Set event handlers. Use weak references back into this object to prevent leaks - self_weakref = weakref.ref(self) - - def on_connect(client, userdata, flags, rc): - this = self_weakref() - logger.info("connected with result code: {}".format(rc)) - - if rc: # i.e. if there is an error - if this.on_mqtt_connection_failure_handler: - try: - this.on_mqtt_connection_failure_handler( - _create_error_from_connack_rc_code(rc) - ) - except Exception: - logger.warning( - "Unexpected error calling on_mqtt_connection_failure_handler" - ) - logger.warning(traceback.format_exc()) - else: - logger.warning( - "connection failed, but no on_mqtt_connection_failure_handler handler callback provided" - ) - elif this.on_mqtt_connected_handler: - try: - this.on_mqtt_connected_handler() - except Exception: - logger.warning("Unexpected error calling on_mqtt_connected_handler") - logger.warning(traceback.format_exc()) - else: - logger.debug("No event handler callback set for on_mqtt_connected_handler") - - def on_disconnect(client, userdata, rc): - this = self_weakref() - logger.info("disconnected with result code: {}".format(rc)) - - cause = None - if rc: # i.e. if there is an error - logger.debug("".join(traceback.format_stack())) - cause = _create_error_from_rc_code(rc) - if this: - this._force_transport_disconnect_and_cleanup() - - if not this: - # Paho will sometimes call this after we've been garbage collected, If so, we have to - # stop the loop to make sure the Paho thread shuts down. - logger.info( - "on_disconnect called with transport==None. Transport must have been garbage collected. 
stopping loop" - ) - client.loop_stop() - else: - if this.on_mqtt_disconnected_handler: - try: - this.on_mqtt_disconnected_handler(cause) - except Exception: - logger.warning("Unexpected error calling on_mqtt_disconnected_handler") - logger.warning(traceback.format_exc()) - else: - logger.warning("No event handler callback set for on_mqtt_disconnected_handler") - - def on_subscribe(client, userdata, mid, granted_qos): - this = self_weakref() - logger.info("suback received for {}".format(mid)) - # subscribe failures are returned from the subscribe() call. This is just - # a notification that a SUBACK was received, so there is no failure case here - this._op_manager.complete_operation(mid) - - def on_unsubscribe(client, userdata, mid): - this = self_weakref() - logger.info("UNSUBACK received for {}".format(mid)) - # unsubscribe failures are returned from the unsubscribe() call. This is just - # a notification that a SUBACK was received, so there is no failure case here - this._op_manager.complete_operation(mid) - - def on_publish(client, userdata, mid): - this = self_weakref() - logger.info("payload published for {}".format(mid)) - # publish failures are returned from the publish() call. 
This is just - # a notification that a PUBACK was received, so there is no failure case here - this._op_manager.complete_operation(mid) - - def on_message(client, userdata, mqtt_message): - this = self_weakref() - logger.info("message received on {}".format(mqtt_message.topic)) - - if this.on_mqtt_message_received_handler: - try: - this.on_mqtt_message_received_handler(mqtt_message.topic, mqtt_message.payload) - except Exception: - logger.warning("Unexpected error calling on_mqtt_message_received_handler") - logger.warning(traceback.format_exc()) - else: - logger.debug( - "No event handler callback set for on_mqtt_message_received_handler - DROPPING MESSAGE" - ) - - mqtt_client.on_connect = on_connect - mqtt_client.on_disconnect = on_disconnect - mqtt_client.on_subscribe = on_subscribe - mqtt_client.on_unsubscribe = on_unsubscribe - mqtt_client.on_publish = on_publish - mqtt_client.on_message = on_message - - # Set paho automatic-reconnect delay to 2 hours. Ideally we would turn - # paho auto-reconnect off entirely, but this is the best we can do. Without - # this, we run the risk of our auto-reconnect code and the paho auto-reconnect - # code conflicting with each other. - # The choice of 2 hours is completely arbitrary - mqtt_client.reconnect_delay_set(120 * 60) - - logger.debug("Created MQTT protocol client, assigned callbacks") - return mqtt_client - - def _force_transport_disconnect_and_cleanup(self): - """ - After disconnecting because of an error, Paho was designed to keep the loop running and - to try reconnecting after the reconnect interval. We don't want Paho to reconnect because - we want to control the timing of the reconnect, so we force the loop to stop. - - We are relying on intimate knowledge of Paho behavior here. If this becomes a problem, - it may be necessary to write our own Paho thread and stop using thread_start()/thread_stop(). 
- This is certainly supported by Paho, but the thread that Paho provides works well enough - (so far) and making our own would be more complex than is currently justified. - """ - - logger.info("Forcing paho disconnect to prevent it from automatically reconnecting") - - # Note: We are calling this inside our on_disconnect() handler, so we might be inside the - # Paho thread at this point. This is perfectly valid. Comments in Paho's client.py - # loop_forever() function re-comment calling disconnect() from a callback to exit the - # Paho thread/loop. - - self._mqtt_client.disconnect() - - # Calling disconnect() isn't enough. We also need to call loop_stop to make sure - # Paho is as clean as possible. Our call to disconnect() above is enough to stop the - # loop and exit the tread, but the call to loop_stop() is necessary to complete the cleanup. - - self._mqtt_client.loop_stop() - - # Finally, because of a bug in Paho, we need to null out the _thread pointer. This - # is necessary because the code that sets _thread to None only gets called if you - # call loop_stop from an external thread (and we're still inside the Paho thread here). - if threading.current_thread() == self._mqtt_client._thread: - logger.debug("in paho thread. nulling _thread") - self._mqtt_client._thread = None - - logger.debug("Done forcing paho disconnect") - - def _create_ssl_context(self): - """ - This method creates the SSLContext object used by Paho to authenticate the connection. 
- """ - logger.debug("creating a SSL context") - ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLSv1_2) - - if self._server_verification_cert: - logger.debug("configuring SSL context with custom server verification cert") - ssl_context.load_verify_locations(cadata=self._server_verification_cert) - else: - logger.debug("configuring SSL context with default certs") - ssl_context.load_default_certs() - - if self._cipher: - try: - logger.debug("configuring SSL context with cipher suites") - ssl_context.set_ciphers(self._cipher) - except ssl.SSLError as e: - # TODO: custom error with more detail? - raise e - - if self._x509_cert is not None: - logger.debug("configuring SSL context with client-side certificate and key") - ssl_context.load_cert_chain( - self._x509_cert.certificate_file, - self._x509_cert.key_file, - self._x509_cert.pass_phrase, - ) - - ssl_context.verify_mode = ssl.CERT_REQUIRED - ssl_context.check_hostname = True - - return ssl_context - - def shutdown(self): - """Shut down the transport. This is (currently) irreversible.""" - # Remove the disconnect handler from Paho. We don't want to trigger any events in response - # to the shutdown and confuse the higher level layers of code. Just end it. - self._mqtt_client.on_disconnect = None - # Now disconnect and do some additional cleanup. - self._force_transport_disconnect_and_cleanup() - self._op_manager.cancel_all_operations() - - def connect(self, password=None): - """ - Connect to the MQTT broker, using hostname and username set at instantiation. - - This method should be called as an entry point before sending any telemetry. - - The password is not required if the transport was instantiated with an x509 certificate. - - If MQTT connection has been proxied, connection will take a bit longer to allow negotiation - with the proxy server. Any errors in the proxy connection process will trigger exceptions - - :param str password: The password for connecting with the MQTT broker (Optional). 
- - :raises: ConnectionFailedError if connection could not be established. - :raises: ConnectionDroppedError if connection is dropped during execution. - :raises: UnauthorizedError if there is an error authenticating. - :raises: NoConnectionError in certain failure scenarios where a connection could not be established - :raises: ProtocolClientError if there is some other client error. - :raises: TlsExchangeAuthError if there a failure with TLS certificate exchange - :raises: ProtocolProxyError if there is a proxy-specific error - """ - logger.debug("connecting to mqtt broker") - - self._mqtt_client.username_pw_set(username=self._username, password=password) - - try: - if self._websockets: - logger.info("Connect using port 443 (websockets)") - rc = self._mqtt_client.connect( - host=self._hostname, port=443, keepalive=self._keep_alive - ) - else: - logger.info("Connect using port 8883 (TCP)") - rc = self._mqtt_client.connect( - host=self._hostname, port=8883, keepalive=self._keep_alive - ) - except socket.error as e: - self._force_transport_disconnect_and_cleanup() - - # Only this type will raise a special error - # To stop it from retrying. - if ( - isinstance(e, ssl.SSLError) - and e.strerror is not None - and "CERTIFICATE_VERIFY_FAILED" in e.strerror - ): - raise exceptions.TlsExchangeAuthError() from e - elif isinstance(e, socks.ProxyError): - if isinstance(e, socks.SOCKS5AuthError): - # TODO This is the only I felt like specializing - raise exceptions.UnauthorizedError() from e - else: - raise exceptions.ProtocolProxyError() from e - else: - # If the socket can't open (e.g. using iptables REJECT), we get a - # socket.error. 
Convert this into ConnectionFailedError so we can retry - raise exceptions.ConnectionFailedError() from e - - except Exception as e: - self._force_transport_disconnect_and_cleanup() - - raise exceptions.ProtocolClientError("Unexpected Paho failure during connect") from e - - logger.debug("_mqtt_client.connect returned rc={}".format(rc)) - if rc: - raise _create_error_from_rc_code(rc) - self._mqtt_client.loop_start() - - def disconnect(self, clear_inflight=False): - """ - Disconnect from the MQTT broker. - - :raises: ProtocolClientError if there is some client error. - :raises: ConnectionDroppedError in unexpected cases. - :raises: UnauthorizedError in unexpected cases. - :raises: ConnectionFailedError in unexpected cases. - :raises: NoConnectionError if the client isn't actually connected. - """ - logger.info("disconnecting MQTT client") - try: - rc = self._mqtt_client.disconnect() - except Exception as e: - raise exceptions.ProtocolClientError("Unexpected Paho failure during disconnect") from e - finally: - self._mqtt_client.loop_stop() - - if threading.current_thread() == self._mqtt_client._thread: - logger.debug("in paho thread. nulling _thread") - self._mqtt_client._thread = None - - logger.debug("_mqtt_client.disconnect returned rc={}".format(rc)) - if rc: - # This could result in ConnectionDroppedError or ProtocolClientError - # No matter what, we always raise here to give upper layers a chance to respond - # to this error. - err = _create_error_from_rc_code(rc) - raise err - else: - # Clear pending ops if instructed, but only if the disconnect was successful. - # Technically the disconnect could still fail upon response, however that would then - # cause a force disconnect via the on_disconnect handler, thus it is safe to clear - # ops here and now. - if clear_inflight: - self._op_manager.cancel_all_operations() - - def subscribe(self, topic, qos=1, callback=None): - """ - This method subscribes the client to one topic from the MQTT broker. 
- - :param str topic: a single string specifying the subscription topic to subscribe to - :param int qos: the desired quality of service level for the subscription. Defaults to 1. - :param callback: A callback to be triggered upon completion (Optional). - - :raises: ValueError if qos is not 0, 1 or 2. - :raises: ValueError if topic is None or has zero string length. - :raises: ConnectionDroppedError if connection is dropped during execution. - :raises: ProtocolClientError if there is some other client error. - :raises: NoConnectionError if the client isn't actually connected. - """ - logger.info("subscribing to {} with qos {}".format(topic, qos)) - try: - (rc, mid) = self._mqtt_client.subscribe(topic, qos=qos) - except ValueError: - raise - except Exception as e: - raise exceptions.ProtocolClientError("Unexpected Paho failure during subscribe") from e - logger.debug("_mqtt_client.subscribe returned rc={}".format(rc)) - if rc: - # This could result in ConnectionDroppedError or ProtocolClientError - raise _create_error_from_rc_code(rc) - self._op_manager.establish_operation(mid, callback) - - def unsubscribe(self, topic, callback=None): - """ - Unsubscribe the client from one topic on the MQTT broker. - - :param str topic: a single string which is the subscription topic to unsubscribe from. - :param callback: A callback to be triggered upon completion (Optional). - - :raises: ValueError if topic is None or has zero string length. - :raises: ConnectionDroppedError if connection is dropped during execution. - :raises: ProtocolClientError if there is some other client error. - :raises: NoConnectionError if the client isn't actually connected. 
- """ - logger.info("unsubscribing from {}".format(topic)) - try: - (rc, mid) = self._mqtt_client.unsubscribe(topic) - except ValueError: - raise - except Exception as e: - raise exceptions.ProtocolClientError( - "Unexpected Paho failure during unsubscribe" - ) from e - logger.debug("_mqtt_client.unsubscribe returned rc={}".format(rc)) - if rc: - # This could result in ConnectionDroppedError or ProtocolClientError - raise _create_error_from_rc_code(rc) - self._op_manager.establish_operation(mid, callback) - - def publish(self, topic, payload, qos=1, callback=None): - """ - Send a message via the MQTT broker. - - :param str topic: topic: The topic that the message should be published on. - :param payload: The actual message to send. - :type payload: str, bytes, int, float or None - :param int qos: the desired quality of service level for the subscription. Defaults to 1. - :param callback: A callback to be triggered upon completion (Optional). - - :raises: ValueError if qos is not 0, 1 or 2 - :raises: ValueError if topic is None or has zero string length - :raises: ValueError if topic contains a wildcard ("+") - :raises: ValueError if the length of the payload is greater than 268435455 bytes - :raises: TypeError if payload is not a valid type - :raises: ConnectionDroppedError if connection is dropped during execution. - :raises: ProtocolClientError if there is some other client error. - :raises: NoConnectionError if the client isn't actually connected. 
- """ - logger.info("publishing on {}".format(topic)) - try: - (rc, mid) = self._mqtt_client.publish(topic=topic, payload=payload, qos=qos) - except ValueError: - raise - except TypeError: - raise - except Exception as e: - raise exceptions.ProtocolClientError("Unexpected Paho failure during publish") from e - logger.debug("_mqtt_client.publish returned rc={}".format(rc)) - if rc: - # This could result in ConnectionDroppedError or ProtocolClientError - raise _create_error_from_rc_code(rc) - self._op_manager.establish_operation(mid, callback) - - -class OperationManager(object): - """Tracks pending operations and their associated callbacks until completion.""" - - def __init__(self): - # Maps mid->callback for operations where a request has been sent - # but the response has not yet been received - self._pending_operation_callbacks = {} - - # Maps mid->mid for responses received that are NOT established in the _pending_operation_callbacks dict. - # Necessary because sometimes an operation will complete with a response before the - # Paho call returns. - # TODO: make this map mid to something more useful (result code?) - self._unknown_operation_completions = {} - - self._lock = threading.Lock() - - def establish_operation(self, mid, callback=None): - """Establish a pending operation identified by MID, and store its completion callback. - - If the operation has already been completed, the callback will be triggered. 
- """ - trigger_callback = False - - with self._lock: - # Check to see if a response was already received for this MID before this method was - # able to be called due to threading shenanigans - if mid in self._unknown_operation_completions: - - # Clear the recorded unknown response now that it has been resolved - del self._unknown_operation_completions[mid] - - # Since the operation has already completed, indicate callback should trigger - trigger_callback = True - - else: - # Store the operation as pending, along with callback - self._pending_operation_callbacks[mid] = callback - logger.debug("Waiting for response on MID: {}".format(mid)) - - # Now that the lock has been released, if the callback should be triggered, - # go ahead and trigger it now. - if trigger_callback: - logger.debug( - "Response for MID: {} was received early - triggering callback".format(mid) - ) - if callback: - try: - callback() - except Exception: - logger.debug("Unexpected error calling callback for MID: {}".format(mid)) - logger.debug(traceback.format_exc()) - else: - # Not entirely unexpected because of QOS=1 - logger.debug("No callback for MID: {}".format(mid)) - - def complete_operation(self, mid): - """Complete an operation identified by MID and trigger the associated completion callback. - - If the operation MID is unknown, the completion status will be stored until - the operation is established. 
- """ - callback = None - trigger_callback = False - - with self._lock: - # If the mid is associated with an established pending operation, trigger the associated callback - if mid in self._pending_operation_callbacks: - - # Retrieve the callback, and clear the pending operation now that it has been completed - callback = self._pending_operation_callbacks[mid] - del self._pending_operation_callbacks[mid] - - # Since the operation is complete, indicate the callback should be triggered - trigger_callback = True - - else: - # Otherwise, store the mid as an unknown response - logger.debug("Response received for unknown MID: {}".format(mid)) - self._unknown_operation_completions[ - mid - ] = mid # TODO: set something more useful here - - # Now that the lock has been released, if the callback should be triggered, - # go ahead and trigger it now. - if trigger_callback: - logger.debug( - "Response received for recognized MID: {} - triggering callback".format(mid) - ) - if callback: - try: - callback() - except Exception: - logger.debug("Unexpected error calling callback for MID: {}".format(mid)) - logger.debug(traceback.format_exc()) - else: - # fully expected. 
QOS=1 means we might get 2 PUBACKs - logger.debug("No callback set for MID: {}".format(mid)) - - def cancel_all_operations(self): - """Complete all pending operations with cancellation, removing MID tracking""" - logger.debug("Cancelling all pending operations") - with self._lock: - # Clear pending operations - pending_ops = list(self._pending_operation_callbacks.items()) - for pending_op in pending_ops: - mid = pending_op[0] - del self._pending_operation_callbacks[mid] - - # Clear unknown responses - unknown_mids = [mid for mid in self._unknown_operation_completions] - for mid in unknown_mids: - del self._unknown_operation_completions[mid] - - # Trigger cancel in pending operation callbacks - for pending_op in pending_ops: - mid = pending_op[0] - callback = pending_op[1] - if callback: - logger.debug("Cancelling {} - Triggering callback".format(mid)) - try: - callback(cancelled=True) - except Exception: - logger.debug("Unexpected error calling callback for MID: {}".format(mid)) - logger.debug(traceback.format_exc()) - else: - logger.debug("Cancelling {} - No callback set for MID".format(mid)) diff --git a/azure-iot-device/azure/iot/device/common/pipeline/__init__.py b/azure-iot-device/azure/iot/device/common/pipeline/__init__.py deleted file mode 100644 index 0a9b5b3ae..000000000 --- a/azure-iot-device/azure/iot/device/common/pipeline/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Azure IoT Hub Device SDK Pipeline - -This package provides pipeline objects for use with the Azure IoT Hub Device SDK. 
- -INTERNAL USAGE ONLY -""" -from .pipeline_events_base import PipelineEvent # noqa: F401 -from .pipeline_ops_base import PipelineOperation # noqa: F401 -from .pipeline_stages_base import PipelineStage # noqa: F401 -from .pipeline_exceptions import OperationCancelled # noqa: F401 diff --git a/azure-iot-device/azure/iot/device/common/pipeline/config.py b/azure-iot-device/azure/iot/device/common/pipeline/config.py deleted file mode 100644 index 81a507b6e..000000000 --- a/azure-iot-device/azure/iot/device/common/pipeline/config.py +++ /dev/null @@ -1,132 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -import threading -import abc -from azure.iot.device import constant - - -logger = logging.getLogger(__name__) - -DEFAULT_KEEPALIVE = 60 - - -class BasePipelineConfig(abc.ABC): - """A base class for storing all configurations/options shared across the Azure IoT Python Device Client Library. - More specific configurations such as those that only apply to the IoT Hub Client will be found in the respective - config files. - """ - - def __init__( - self, - hostname, - gateway_hostname=None, - sastoken=None, - x509=None, - server_verification_cert=None, - websockets=False, - cipher="", - proxy_options=None, - keep_alive=DEFAULT_KEEPALIVE, - auto_connect=True, - connection_retry=True, - connection_retry_interval=10, - ): - """Initializer for BasePipelineConfig - - :param str hostname: The hostname being connected to - :param str gateway_hostname: The gateway hostname optionally being used - :param sastoken: SasToken to be used for authentication. Mutually exclusive with x509. - :type sastoken: :class:`azure.iot.device.common.auth.SasToken` - :param x509: X509 to be used for authentication. 
Mutually exclusive with sastoken. - :type x509: :class:`azure.iot.device.models.X509` - :param str server_verification_cert: The trusted certificate chain. - Necessary when using connecting to an endpoint which has a non-standard root of trust, - such as a protocol gateway. - :param bool websockets: Enabling/disabling websockets in MQTT. This feature is relevant - if a firewall blocks port 8883 from use. - :param cipher: Optional cipher suite(s) for TLS/SSL, as a string in - "OpenSSL cipher list format" or as a list of cipher suite strings. - :type cipher: str or list(str) - :param proxy_options: Details of proxy configuration - :type proxy_options: :class:`azure.iot.device.common.models.ProxyOptions` - :param int keepalive: Maximum period in seconds between communications with the - broker. - :param bool auto_connect: Indicates if automatic connects should occur - :param bool connection_retry: Indicates if dropped connection should result in attempts to - re-establish it - :param int connection_retry_interval: Interval (in seconds) between connection retries - """ - # Network - self.hostname = hostname - self.gateway_hostname = gateway_hostname - self.keep_alive = self._sanitize_keep_alive(keep_alive) - - # Auth - self.sastoken = sastoken - self.x509 = x509 - if (not sastoken and not x509) or (sastoken and x509): - raise ValueError("One of either 'sastoken' or 'x509' must be provided") - self.server_verification_cert = server_verification_cert - self.websockets = websockets - self.cipher = self._sanitize_cipher(cipher) - self.proxy_options = proxy_options - - # Pipeline - self.auto_connect = auto_connect - self.connection_retry = connection_retry - self.connection_retry_interval = self._sanitize_connection_retry_interval( - connection_retry_interval - ) - - @staticmethod - def _sanitize_cipher(cipher): - """Sanitize the cipher input and convert to a string in OpenSSL list format""" - if isinstance(cipher, list): - cipher = ":".join(cipher) - - if 
isinstance(cipher, str): - cipher = cipher.upper() - cipher = cipher.replace("_", "-") - else: - raise TypeError("Invalid type for 'cipher'") - - return cipher - - @staticmethod - def _sanitize_keep_alive(keep_alive): - try: - keep_alive = int(keep_alive) - except (ValueError, TypeError): - raise TypeError("Invalid type for 'keep alive'. Must be a numeric value.") - - if keep_alive <= 0: - # Not allowing a keep alive of 0 as this would mean frequent ping exchanges. - raise ValueError("'keep alive' must be greater than 0") - - if keep_alive > constant.MAX_KEEP_ALIVE_SECS: - raise ValueError("'keep_alive' cannot exceed 1740 seconds (29 minutes)") - - return keep_alive - - @staticmethod - def _sanitize_connection_retry_interval(connection_retry_interval): - try: - connection_retry_interval = int(connection_retry_interval) - except (ValueError, TypeError): - raise TypeError("Invalid type for 'connection_retry_interval'. Must be a numeric value") - - if connection_retry_interval > threading.TIMEOUT_MAX: - # Python timers have a (platform dependent) max timeout. - raise ValueError( - "'connection_retry_interval' cannot exceed {} seconds".format(threading.TIMEOUT_MAX) - ) - - if connection_retry_interval <= 0: - raise ValueError("'connection_retry_interval' must be greater than 0") - - return connection_retry_interval diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_events_base.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_events_base.py deleted file mode 100644 index 7de1088d3..000000000 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_events_base.py +++ /dev/null @@ -1,99 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - - -class PipelineEvent(object): - """ - A base class for data objects representing events that travels up the pipeline. - - PipelineEvent objects are used for anything that happens inside the pipeline that - cannot be attributed to a specific operation, such as a spontaneous disconnect. - - PipelineEvents flow up the pipeline until they reach the client. Every stage - has the opportunity to handle a given event. If they don't handle it, they - should pass it up to the next stage (this is the default behavior). Stages - have the opportunity to tie a PipelineEvent to a PipelineOperation object - if they are waiting for a response for that particular operation. - - :ivar name: The name of the event. This is used primarily for logging - :type name: str - """ - - def __init__(self): - """ - Initializer for PipelineEvent objects. - """ - if self.__class__ == PipelineEvent: - raise TypeError( - "Cannot instantiate PipelineEvent object. You need to use a derived class" - ) - self.name = self.__class__.__name__ - - -class ResponseEvent(PipelineEvent): - """ - A PipelineEvent object which is the second part of an RequestAndResponseOperation operation - (the response). The RequestAndResponseOperation represents the common operation of sending - a request to iothub with a request_id ($rid) value and waiting for a response with - the same $rid value. This convention is used by both Twin and Provisioning features. - - The response represented by this event has not yet been matched to the corresponding - RequestOperation operation. That matching is done by the CoordinateRequestAndResponseStage - stage which takes the contents of this event and puts it into the RequestAndResponseOperation - operation with the matching $rid value. - - :ivar request_id: The request ID which will eventually be used to match a RequestOperation - operation to this event. 
- :type request_id: str - :ivar status_code: The status code returned by the response. Any value under 300 is - considered success. - :type status_code: int - :ivar response_body: The body of the response. - :type response_body: str - :ivar retry_after: A retry interval value that was extracted from the topic. - :type retry_after: int - """ - - def __init__(self, request_id, status_code, response_body, retry_after=None): - super().__init__() - self.request_id = request_id - self.status_code = status_code - self.response_body = response_body - self.retry_after = retry_after - - -class ConnectedEvent(PipelineEvent): - """ - A PipelineEvent object indicating a connection has been established. - """ - - pass - - -class DisconnectedEvent(PipelineEvent): - """ - A PipelineEvent object indicating a connection has been dropped. - """ - - pass - - -class NewSasTokenRequiredEvent(PipelineEvent): - """ - A PipelineEvent object indicating that a new SasToken must be provided. - """ - - pass - - -class BackgroundExceptionEvent(PipelineEvent): - """ - An exception was raised in a background thread - """ - - def __init__(self, e): - super().__init__() - self.e = e diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_events_mqtt.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_events_mqtt.py deleted file mode 100644 index f92f0ff0f..000000000 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_events_mqtt.py +++ /dev/null @@ -1,23 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from . 
import PipelineEvent - - -class IncomingMQTTMessageEvent(PipelineEvent): - """ - A PipelineEvent object which represents an incoming MQTT message on some MQTT topic - """ - - def __init__(self, topic, payload): - """ - Initializer for IncomingMQTTMessageEvent objects. - - :param str topic: The name of the topic that the incoming message arrived on. - :param str payload: The payload of the message - """ - super().__init__() - self.topic = topic - self.payload = payload diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_exceptions.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_exceptions.py deleted file mode 100644 index 51ed4e941..000000000 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_exceptions.py +++ /dev/null @@ -1,42 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -"""This module defines exceptions that may be raised from a pipeline""" - - -class PipelineException(Exception): - """Generic pipeline exception""" - - pass - - -class OperationCancelled(PipelineException): - """Operation was cancelled""" - - pass - - -class OperationTimeout(PipelineException): - """Pipeline operation timed out""" - - pass - - -class OperationError(PipelineException): - """Error while executing an Operation""" - - pass - - -class PipelineNotRunning(PipelineException): - """Pipeline is not currently running""" - - pass - - -class PipelineRuntimeError(PipelineException): - """Error at runtime caused by incorrect pipeline configuration""" - - pass diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_nucleus.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_nucleus.py deleted file mode 100644 index e572cc275..000000000 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_nucleus.py +++ /dev/null @@ -1,29 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -from enum import Enum - - -class PipelineNucleus(object): - """Contains data and information shared across the pipeline""" - - def __init__(self, pipeline_configuration): - self.pipeline_configuration = pipeline_configuration - self.connection_state = ConnectionState.DISCONNECTED - - @property - def connected(self): - # Only return True if fully connected - return self.connection_state is ConnectionState.CONNECTED - - -class ConnectionState(Enum): - CONNECTED = "CONNECTED" # Client is connected (as far as it knows) - DISCONNECTED = "DISCONNECTED" # Client is disconnected - CONNECTING = "CONNECTING" # Client is in the process of connecting - DISCONNECTING = "DISCONNECTING" # Client is in the process of disconnecting - REAUTHORIZING = "REAUTHORIZING" # Client is in the process of reauthorizing - # NOTE: Reauthorizing is the process of doing a disconnect, then a connect at transport level diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_base.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_base.py deleted file mode 100644 index 50557af32..000000000 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_base.py +++ /dev/null @@ -1,421 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import logging -from . import pipeline_exceptions -from . import pipeline_thread - -logger = logging.getLogger(__name__) - - -class PipelineOperation(object): - """ - A base class for data objects representing operations that travels down the pipeline. - - Each PipelineOperation object represents a single asynchronous operation that is performed - by the pipeline. 
The PipelineOperation objects travel through "stages" of the pipeline, - and each stage has the opportunity to act on each specific operation that it - receives. If a stage does not handle a particular operation, it needs to pass it to the - next stage. If the operation gets to the end of the pipeline without being handled - (completed), then it is treated as an error. - - :ivar name: The name of the operation. This is used primarily for logging - :type name: str - :ivar callback: The callback that is called when the operation is completed, either - successfully or with a failure. - :type callback: Function - :ivar needs_connection: This is an attribute that indicates whether a particular operation - requires a connection to operate. This is currently used by the AutoConnectStage - stage, but this functionality will be revamped shortly. - :type needs_connection: Boolean - :ivar error: The presence of a value in the error attribute indicates that the operation failed, - absence of this value indicates that the operation either succeeded or hasn't been handled yet. - :type error: Error - """ - - def __init__(self, callback): - """ - Initializer for PipelineOperation objects. - - :param Function callback: The function that gets called when this operation is complete or has - failed. The callback function must accept A PipelineOperation object which indicates - the specific operation which has completed or failed. - """ - if self.__class__ == PipelineOperation: - raise TypeError( - "Cannot instantiate PipelineOperation object. 
You need to use a derived class" - ) - self.name = self.__class__.__name__ - self.callback_stack = [] - self.needs_connection = False - self.completed = False # Operation has been fully completed - self.completing = False # Operation is in the process of completing - self.error = None # Error associated with Operation completion - - self.add_callback(callback) - - def add_callback(self, callback): - """Adds a callback to the Operation that will be triggered upon Operation completion. - - When an Operation is completed, all callbacks will be resolved in LIFO order. - - Callbacks cannot be added to an already completed operation, or an operation that is - currently undergoing a completion process. - - :param callback: The callback to add to the operation. - - :raises: OperationError if the operation is already completed, or is in the process of - completing. - """ - if self.completed: - raise pipeline_exceptions.OperationError( - "{}: Attempting to add a callback to an already-completed operation!".format( - self.name - ) - ) - if self.completing: - raise pipeline_exceptions.OperationError( - "{}: Attempting to add a callback to a operation with completion in progress!".format( - self.name - ) - ) - else: - self.callback_stack.append(callback) - - @pipeline_thread.runs_on_pipeline_thread - def complete(self, error=None): - """Complete the operation, and trigger all callbacks in LIFO order. - - The operation is completed successfully be default, or completed unsuccessfully if an error - is provided. - - An operation that is already fully completed, or in the process of completion cannot be - completed again. - - This process can be halted if a callback for the operation invokes the .halt_completion() - method on this Operation. - - Note that if an error is raised the operation may not be able to be completed. - - :param error: Optionally provide an Exception object indicating the error that caused - the completion. 
Providing an error indicates that the operation was unsuccessful. - - :raises: OperationError if the operation cannot properly reach a completed state - :raises: OperationError if an error occurs in resolving any callbacks - """ - if error: - logger.debug("{}: completing with error {}".format(self.name, error)) - else: - logger.debug("{}: completing without error".format(self.name)) - - if self.completed or self.completing: - raise pipeline_exceptions.OperationError( - "Attempting to complete an already-completed operation: {}".format(self.name) - ) - else: - # Operation is now in the process of completing - self.completing = True - self.error = error - - while self.callback_stack: - if not self.completing: - logger.debug("{}: Completion halted!".format(self.name)) - break - if self.completed: - # This block should never be reached - this is an invalid state. - # If this block is reached, there is a bug in the code. - self.halt_completion() - raise pipeline_exceptions.OperationError( - "Operation reached fully completed state while still resolving completion: {}".format( - self.name - ) - ) - - callback = self.callback_stack.pop() - try: - callback(op=self, error=error) - except Exception as e: - logger.warning( - "Unhandled error while triggering callback for {}".format(self.name) - ) - self.halt_completion() - raise pipeline_exceptions.OperationError( - "Exception occurred while triggering completion callback" - ) from e - - if self.completing: - # Operation is now completed, no longer in the process of completing - self.completing = False - self.completed = True - - @pipeline_thread.runs_on_pipeline_thread - def halt_completion(self): - """Halt the completion of an operation that is currently undergoing a completion process - as a result of a call to .complete(). - - Completion cannot be halted if there is no currently ongoing completion process. The only - way to successfully invoke this method is from within a callback on the Operation in - question. 
- - This method will leave any yet-untriggered callbacks on the Operation to be triggered upon - a later completion. - - This method will clear any error associated with the currently ongoing completion process - from the Operation. - """ - if not self.completing: - raise pipeline_exceptions.OperationError( - "Attempting to halt completion of an operation not in the process of completion: {}".format( - self.name - ) - ) - else: - self.completing = False - self.error = None - logger.debug("{}: Operation completion halted".format(self.name)) - - @pipeline_thread.runs_on_pipeline_thread - def spawn_worker_op(self, worker_op_type, **kwargs): - """Create and return a new operation, which, when completed, will complete the operation - it was spawned from. - - :param worker_op_type: The type (class) of the new worker operation. - :param **kwargs: The arguments to instantiate the new worker operation with. Note that a - callback is not required, but if provided, will be triggered prior to completing the - operation that spawned the worker operation. - - :returns: A new worker operation of the type specified in the worker_op_type parameter. - """ - logger.debug("{}: creating worker op of type {}".format(self.name, worker_op_type.__name__)) - - @pipeline_thread.runs_on_pipeline_thread - def on_worker_op_complete(op, error): - logger.debug("{}: Worker op ({}) has been completed".format(self.name, op.name)) - self.complete(error=error) - - if "callback" in kwargs: - provided_callback = kwargs["callback"] - kwargs["callback"] = on_worker_op_complete - worker_op = worker_op_type(**kwargs) - worker_op.add_callback(provided_callback) - else: - kwargs["callback"] = on_worker_op_complete - worker_op = worker_op_type(**kwargs) - - return worker_op - - -class InitializePipelineOperation(PipelineOperation): - """ - A PipelineOperation for doing initial setup of the pipeline - - Attributes can be dynamically added to this operation for use in other stages if necessary - (e.g. 
initialization requires a derived value) - """ - - pass - - -class ShutdownPipelineOperation(PipelineOperation): - """ - A PipelineOperation for doing teardown of the pipeline. - """ - - pass - - -class ConnectOperation(PipelineOperation): - """ - A PipelineOperation object which tells the pipeline to connect to whatever service it needs to connect to. - - This operation is in the group of base operations because connecting is a common operation that many clients might need to do. - - Even though this is an base operation, it will most likely be handled by a more specific stage (such as an IoTHub or MQTT stage). - """ - - def __init__(self, callback): - self.watchdog_timer = None - super().__init__(callback) - - -class ReauthorizeConnectionOperation(PipelineOperation): - """ - A PipelineOperation object which tells the pipeline to reauthorize the connection to whatever service it is connected to. - - Clients will most-likely submit a ReauthorizeConnectionOperation when some credential (such as a sas token) has changed and the protocol client - needs to re-establish the connection to refresh the credentials - - This operation is in the group of base operations because reauthorizing is a common operation that many clients might need to do. - - Even though this is an base operation, it will most likely be handled by a more specific stage (such as an IoTHub or MQTT stage). - """ - - pass - - -class DisconnectOperation(PipelineOperation): - """ - A PipelineOperation object which tells the pipeline to disconnect from whatever service it might be connected to. - - This operation is in the group of base operations because disconnecting is a common operation that many clients might need to do. - - Even though this is an base operation, it will most likely be handled by a more specific stage (such as an IoTHub or MQTT stage). 
- """ - - def __init__(self, callback): - self.hard = True # Indicates if this is a "hard" disconnect that kills in-flight ops - super().__init__(callback) - - -class EnableFeatureOperation(PipelineOperation): - """ - A PipelineOperation object which tells the pipeline to "enable" a particular feature. - - A "feature" is just a string which represents some set of functionality that needs to be enabled, such as "C2D" or "Twin". - - This object has no notion of what it means to "enable" a feature. That knowledge is handled by stages in the pipeline which might convert - this operation to a more specific operation (such as an MQTT subscribe operation with a specific topic name). - - This operation is in the group of base operations because disconnecting is a common operation that many clients might need to do. - - Even though this is an base operation, it will most likely be handled by a more specific stage (such as an IoTHub or MQTT stage). - """ - - def __init__(self, feature_name, callback): - """ - Initializer for EnableFeatureOperation objects. - - :param str feature_name: Name of the feature that is being enabled. The meaning of this - string is defined in the stage which handles this operation. - :param Function callback: The function that gets called when this operation is complete or has - failed. The callback function must accept A PipelineOperation object which indicates - the specific operation which has completed or failed. - """ - super().__init__(callback=callback) - self.feature_name = feature_name - - -class DisableFeatureOperation(PipelineOperation): - """ - A PipelineOperation object which tells the pipeline to "disable" a particular feature. - - A "feature" is just a string which represents some set of functionality that needs to be disabled, such as "C2D" or "Twin". - - This object has no notion of what it means to "disable" a feature. 
That knowledge is handled by stages in the pipeline which might convert - this operation to a more specific operation (such as an MQTT unsubscribe operation with a specific topic name). - - This operation is in the group of base operations because disconnecting is a common operation that many clients might need to do. - - Even though this is an base operation, it will most likely be handled by a more specific stage (such as an IoTHub or MQTT stage). - """ - - def __init__(self, feature_name, callback): - """ - Initializer for DisableFeatureOperation objects. - - :param str feature_name: Name of the feature that is being disabled. The meaning of this - string is defined in the stage which handles this operation. - :param Function callback: The function that gets called when this operation is complete or has - failed. The callback function must accept A PipelineOperation object which indicates - the specific operation which has completed or failed. - """ - super().__init__(callback=callback) - self.feature_name = feature_name - - -class RequestAndResponseOperation(PipelineOperation): - """ - A PipelineOperation object which wraps the common operation of sending a request to iothub with a request_id ($rid) - value and waiting for a response with the same $rid value. This convention is used by both Twin and Provisioning - features. - - Even though this is an base operation, it will most likely be generated and also handled by more specifics stages - (such as IoTHub or MQTT stages). - - The type of the request payload and the response payload is undefined at this level. The type of the payload is defined - based on the type of request that is being executed. If types need to be converted, that is the responsibility of - the stage which creates this operation, and also the stage which executes on the operation. - - :ivar status_code: The status code returned by the response. Any value under 300 is considered success. 
- :type status_code: int - :ivar response_body: The body of the response. - :type response_body: Undefined - :ivar query_params: Any query parameters that need to be sent with the request. - Example is the id of the operation as returned by the initial provisioning request. - """ - - def __init__( - self, request_type, method, resource_location, request_body, callback, query_params=None - ): - """ - Initializer for RequestAndResponseOperation objects - - :param str request_type: The type of request. This is a string which is used by protocol-specific stages to - generate the actual request. For example, if request_type is "twin", then the iothub_mqtt stage will convert - the request into an MQTT publish with topic that begins with $iothub/twin - :param str method: The method for the request, in the REST sense of the word, such as "POST", "GET", etc. - :param str resource_location: The resource that the method is acting on, in the REST sense of the word. - For twin request with method "GET", this is most likely the string "/" which retrieves the entire twin - :param request_body: The body of the request. This is a required field, and a single space can be used to denote - an empty body. - :type request_body: Undefined - :param Function callback: The function that gets called when this operation is complete or has - failed. The callback function must accept A PipelineOperation object which indicates - the specific operation which has completed or failed. - """ - super().__init__(callback=callback) - self.request_type = request_type - self.method = method - self.resource_location = resource_location - self.request_body = request_body - self.status_code = None - self.response_body = None - self.query_params = query_params - - -class RequestOperation(PipelineOperation): - """ - A PipelineOperation object which is the first part of an RequestAndResponseOperation operation (the request). 
The second - part of the RequestAndResponseOperation operation (the response) is returned via an ResponseEvent event. - - Even though this is an base operation, it will most likely be generated and also handled by more specifics stages - (such as IoTHub or MQTT stages). - """ - - def __init__( - self, - request_type, - method, - resource_location, - request_body, - request_id, - callback, - query_params=None, - ): - """ - Initializer for RequestOperation objects - - :param str request_type: The type of request. This is a string which is used by protocol-specific stages to - generate the actual request. For example, if request_type is "twin", then the iothub_mqtt stage will convert - the request into an MQTT publish with topic that begins with $iothub/twin - :param str method: The method for the request, in the REST sense of the word, such as "POST", "GET", etc. - :param str resource_location: The resource that the method is acting on, in the REST sense of the word. - For twin request with method "GET", this is most likely the string "/" which retrieves the entire twin - :param request_body: The body of the request. This is a required field, and a single space can be used to denote - an empty body. - :type request_body: dict, str, int, float, bool, or None (JSON compatible values) - :param Function callback: The function that gets called when this operation is complete or has - failed. The callback function must accept A PipelineOperation object which indicates - the specific operation which has completed or failed. - :type query_params: Any query parameters that need to be sent with the request. - Example is the id of the operation as returned by the initial provisioning request. 
- """ - super().__init__(callback=callback) - self.method = method - self.resource_location = resource_location - self.request_type = request_type - self.request_body = request_body - self.request_id = request_id - self.query_params = query_params diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_http.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_http.py deleted file mode 100644 index cc528bc1c..000000000 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_http.py +++ /dev/null @@ -1,36 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from . import PipelineOperation - - -class HTTPRequestAndResponseOperation(PipelineOperation): - """ - A PipelineOperation object which contains arguments used to connect to a server using the HTTP protocol. - - This operation is in the group of HTTP operations because its attributes are very specific to the HTTP protocol. - """ - - def __init__(self, method, path, headers, body, query_params, callback): - """ - Initializer for HTTPPublishOperation objects. - :param str method: The HTTP method used in the request - :param str path: The path to be used in the request url - :param dict headers: The headers to be used in the HTTP request - :param str body: The body to be provided with the HTTP request - :param str query_params: The query parameters to be used in the request url - :param Function callback: The function that gets called when this operation is complete or has failed. - The callback function must accept A PipelineOperation object which indicates the specific operation which - has completed or failed. 
- """ - super().__init__(callback=callback) - self.method = method - self.path = path - self.headers = headers - self.body = body - self.query_params = query_params - self.status_code = None - self.response_body = None - self.reason = None diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_mqtt.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_mqtt.py deleted file mode 100644 index 9c47336c9..000000000 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_mqtt.py +++ /dev/null @@ -1,76 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from . import PipelineOperation - - -class MQTTPublishOperation(PipelineOperation): - """ - A PipelineOperation object which contains arguments used to publish a specific payload on a specific topic using the MQTT protocol. - - This operation is in the group of MQTT operations because its attributes are very specific to the MQTT protocol. - """ - - def __init__(self, topic, payload, callback): - """ - Initializer for MQTTPublishOperation objects. - - :param str topic: The name of the topic to publish to - :param str payload: The payload to publish - :param Function callback: The function that gets called when this operation is complete or has failed. - The callback function must accept A PipelineOperation object which indicates the specific operation which - has completed or failed. 
- """ - super().__init__(callback=callback) - self.topic = topic - self.payload = payload - self.needs_connection = True - self.retry_timer = None - - -class MQTTSubscribeOperation(PipelineOperation): - """ - A PipelineOperation object which contains arguments used to subscribe to a specific MQTT topic using the MQTT protocol. - - This operation is in the group of MQTT operations because its attributes are very specific to the MQTT protocol. - """ - - def __init__(self, topic, callback): - """ - Initializer for MQTTSubscribeOperation objects. - - :param str topic: The name of the topic to subscribe to - :param Function callback: The function that gets called when this operation is complete or has failed. - The callback function must accept A PipelineOperation object which indicates the specific operation which - has completed or failed. - """ - super().__init__(callback=callback) - self.topic = topic - self.needs_connection = True - self.timeout_timer = None - self.retry_timer = None - - -class MQTTUnsubscribeOperation(PipelineOperation): - """ - A PipelineOperation object which contains arguments used to unsubscribe from a specific MQTT topic using the MQTT protocol. - - This operation is in the group of MQTT operations because its attributes are very specific to the MQTT protocol. - """ - - def __init__(self, topic, callback): - """ - Initializer for MQTTUnsubscribeOperation objects. - - :param str topic: The name of the topic to unsubscribe from - :param Function callback: The function that gets called when this operation is complete or has failed. - The callback function must accept A PipelineOperation object which indicates the specific operation which - has completed or failed. 
- """ - super().__init__(callback=callback) - self.topic = topic - self.needs_connection = True - self.timeout_timer = None - self.retry_timer = None diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_base.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_base.py deleted file mode 100644 index 82874c0f3..000000000 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_base.py +++ /dev/null @@ -1,1369 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -import abc -import time -import traceback -import uuid -import weakref -import threading -import queue -from . import pipeline_events_base -from . import pipeline_ops_base, pipeline_ops_mqtt -from . import pipeline_thread -from . import pipeline_exceptions -from .pipeline_nucleus import ConnectionState -from azure.iot.device.common import transport_exceptions, alarm -from azure.iot.device.common.auth import sastoken as st - -logger = logging.getLogger(__name__) - - -class PipelineStage(abc.ABC): - """ - Base class representing a stage in the processing pipeline. Each stage is responsible for receiving - PipelineOperation objects from the top, possibly processing them, and possibly passing them down. It - is also responsible for receiving PipelineEvent objects from the bottom, possibly processing them, and - possibly passing them up. - - Each PipelineStage in the pipeline, is expected to act on some well-defined set of PipelineOperation - types and/or some set of PipelineEvent types. If any stage does not act on an operation or event, it - should pass it to the next stage (for operations) or the previous stage (for events). 
In this way, the - pipeline implements the "chain of responsibility" design pattern (Gamma, et.al. "Design Patterns". - Addison Wesley. 1995), with each stage being responsible for implementing some "rule" or "policy" of the - pipeline, and each stage being ignorant of the stages that are before or after it in the pipeline. - - Each stage in the pipeline should act on the smallest set of rules possible, thus making stages small - and easily testable. Complex logic should be the exception and not the rule, and complex stages should - operate on the most generic type of operation possible, thus allowing us to re-use complex logic for - multiple cases. The best way to do this is with "converter" stages that convert a specific operation to - a more general one and with other converter stages that convert general operations to more specific ones. - - An example of a specific-to-generic stage is UseSkAuthProviderStage which takes a specific operation - (use an auth provider) and converts it into something more generic (here is your device_id, etc, and use - this SAS token when connecting). - - An example of a generic-to-specific stage is IoTHubMQTTTranslationStage which converts IoTHub operations - (such as SendD2CMessageOperation) to MQTT operations (such as Publish). - - Each stage should also work in the broadest domain possible. For example a generic stage (say - "AutoConnectStage") that initiates a connection if any arbitrary operation needs a connection is more useful - than having some MQTT-specific code that re-connects to the MQTT broker if the user calls Publish and - there's no connection. - - One way to think about stages is to look at every "block of functionality" in your code and ask yourself - "is this the one and only time I will need this code"? If the answer is no, it might be worthwhile to - implement that code in it's own stage in a very generic way. - - - :ivar name: The name of the stage. 
This is used primarily for logging - :type name: str - :ivar next: The next stage in the pipeline. Set to None if this is the last stage in the pipeline. - :type next: PipelineStage - :ivar previous: The previous stage in the pipeline. Set to None if this is the first stage in the pipeline. - :type previous: PipelineStage - :ivar nucleus: The pipeline's "nucleus" which contains global pipeline information, accessible - from all stages - :type nucleus: PipelineNucleus - """ - - def __init__(self): - """ - Initializer for PipelineStage objects. - """ - self.name = self.__class__.__name__ - self.next = None - self.previous = None - self.nucleus = None - - @pipeline_thread.runs_on_pipeline_thread - def run_op(self, op): - """ - Run the given operation. This is the public function that outside callers would call to run an - operation. Derived classes should override the private _run_op function to implement - stage-specific behavior. When run_op returns, that doesn't mean that the operation has executed - to completion. Rather, it means that the pipeline has done something that will cause the - operation to eventually execute to completion. That might mean that something was sent over - the network and some stage is waiting for a reply, or it might mean that the operation is sitting - in a queue until something happens, or it could mean something entirely different. The only - thing you can assume is that the operation will _eventually_ complete successfully or fail, and the - operation's callback will be called when that happens. - - :param PipelineOperation op: The operation to run. - """ - try: - self._run_op(op) - except Exception as e: - # This path is ONLY for unexpected errors. Expected errors should cause a fail completion - # within ._run_op(). - # - # We tag errors from here as logger.warning because, while we return them to the - # caller and rely on the caller to handle them, they're somewhat unexpected and might be - # worthy of investigation. 
- - # Do not use exc_info parameter on logger.* calls. This causes pytest to save the - # traceback which saves stack frames which shows up as a leak - logger.warning(msg="Unexpected error in {}._run_op() call".format(self)) - logger.warning(traceback.format_exc()) - - # Only complete the operation if it is not already completed. - # Attempting to complete a completed operation would raise an exception. - if not op.completed: - op.complete(error=e) - else: - # Note that this would be very unlikely to occur. It could only happen if a stage - # was doing something after completing an operation, and an exception was raised, - # which is unlikely because stages usually don't do anything after completing an - # operation. - raise e - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - """ - Implementation of the stage-specific function of .run_op(). Override this method instead of - .run_op() in child classes in order to change how a stage behaves when running an operation. - - See the description of the .run_op() method for more discussion on what it means to "run" - an operation. - - :param PipelineOperation op: The operation to run. - """ - self.send_op_down(op) - - @pipeline_thread.runs_on_pipeline_thread - def handle_pipeline_event(self, event): - """ - Handle a pipeline event that arrives from the stage below this stage. Derived - classes should not override this function. Any stage-specific handling of - PipelineEvent objects should be implemented by overriding the private - _handle_pipeline_event function in the derived stage. - - :param PipelineEvent event: The event that is being passed back up the pipeline - """ - try: - self._handle_pipeline_event(event) - except Exception as e: - # Do not use exc_info parameter on logger.* calls. 
This causes pytest to save the - # traceback which saves stack frames which shows up as a leak - logger.warning( - msg="{}: Unexpected error in ._handle_pipeline_event() call: {}".format(self, e) - ) - if self.previous: - logger.warning("{}: Raising background exception") - self.report_background_exception(e) - else: - # Nothing else we can do but log this. There exists no stage we can send the - # exception to, and raising would send the error back down the pipeline. - logger.warning( - "{}: Cannot report a background exception because there is no previous stage!" - ) - - @pipeline_thread.runs_on_pipeline_thread - def _handle_pipeline_event(self, event): - """ - Handle a pipeline event that arrives from the stage below this stage. This - is a function that is intended to be overridden in any stages that want to implement - stage-specific handling of any events - - :param PipelineEvent event: The event that is being passed back up the pipeline - """ - self.send_event_up(event) - - @pipeline_thread.runs_on_pipeline_thread - def send_op_down(self, op): - """ - Helper function to continue a given operation by passing it to the next stage - in the pipeline. If there is no next stage in the pipeline, this function - will fail the operation and call complete_op to return the failure back up the - pipeline. - - :param PipelineOperation op: Operation which is being passed on - """ - if self.next: - self.next.run_op(op) - else: - # This shouldn't happen if the pipeline was created correctly - logger.warning( - "{}({}): no next stage.cannot send op down. completing with error".format( - self.name, op.name - ) - ) - raise pipeline_exceptions.PipelineRuntimeError( - "{} not handled after {} stage with no next stage".format(op.name, self.name) - ) - - @pipeline_thread.runs_on_pipeline_thread - def send_event_up(self, event): - """ - Helper function to pass an event to the previous stage of the pipeline. 
This is the default - behavior of events while traveling through the pipeline. They start somewhere (maybe the - bottom) and move up the pipeline until they're handled or until they error out. - """ - if self.previous: - self.previous.handle_pipeline_event(event) - else: - # This shouldn't happen if the pipeline was created correctly - logger.critical( - "{}({}): no previous stage. cannot send event up".format(event.name, self.name) - ) - # NOTE: We can't report a background exception here because that involves - # sending an event up, which is what got us into this problem in the first place. - # Instead, raise, and let the method invoking this method handle it - raise pipeline_exceptions.PipelineRuntimeError( - "{} not handled after {} stage with no previous stage".format(event.name, self.name) - ) - - @pipeline_thread.runs_on_pipeline_thread - def report_background_exception(self, e): - """ - Send an exception up the pipeline that occurred in the background. - These would typically be in response to unsolicited actions, such as receiving data or - timer-based operations, which cannot be raised to the user because they occurred on a - non-application thread. - - Note that this function leverages pipeline event flow, which means that any background - exceptions in the core event flow itself become problematic (it's a good thing it's well - tested then!) - - :param Exception e: The exception that occurred in the background - """ - event = pipeline_events_base.BackgroundExceptionEvent(e) - self.send_event_up(event) - - -class PipelineRootStage(PipelineStage): - """ - Object representing the root of a pipeline. This is where the functions to build - the pipeline exist. This is also where clients can add event handlers to receive - events from the pipeline. - - :ivar on_pipeline_event_handler: Handler which can be set by users of the pipeline to - receive PipelineEvent objects. 
This is how users receive any "unsolicited" - events from the pipeline (such as C2D messages). This function is called with - a PipelineEvent object every time any such event occurs. - :type on_pipeline_event_handler: Function - :ivar on_connected_handler: Handler which can be set by users of the pipeline to - receive events every time the underlying transport connects - :type on_connected_handler: Function - :ivar on_disconnected_handler: Handler which can be set by users of the pipeline to - receive events every time the underlying transport disconnects - :type on_disconnected_handler: Function - """ - - def __init__(self, nucleus): - super().__init__() - self.on_pipeline_event_handler = None - self.on_connected_handler = None - self.on_disconnected_handler = None - self.on_new_sastoken_required_handler = None - self.on_background_exception_handler = None - self.nucleus = nucleus - - def run_op(self, op): - # CT-TODO: make this more elegant - op.callback_stack[0] = pipeline_thread.invoke_on_callback_thread_nowait( - op.callback_stack[0] - ) - pipeline_thread.invoke_on_pipeline_thread(super().run_op)(op) - - def append_stage(self, new_stage): - """ - Add the next stage to the end of the pipeline. This is the function that callers - use to build the pipeline by appending stages. This function returns the root of - the pipeline so that calls to this function can be chained together. - - :param PipelineStage new_stage: Stage to add to the end of the pipeline - :returns: The root of the pipeline. - """ - old_tail = self - while old_tail.next: - old_tail = old_tail.next - old_tail.next = new_stage - new_stage.previous = old_tail - new_stage.nucleus = self.nucleus - return self - - @pipeline_thread.runs_on_pipeline_thread - def _handle_pipeline_event(self, event): - """ - Override of the PipelineEvent handler. Because this is the root of the pipeline, - this function calls the on_pipeline_event_handler to pass the event to the - caller. 
- - :param PipelineEvent event: Event to be handled, i.e. returned to the caller - through the handle_pipeline_event (if provided). - """ - # Base events that are common to all pipelines are handled here - if isinstance(event, pipeline_events_base.ConnectedEvent): - logger.debug( - "{}: ConnectedEvent received. Calling on_connected_handler".format(self.name) - ) - - if self.on_connected_handler: - pipeline_thread.invoke_on_callback_thread_nowait(self.on_connected_handler)() - - elif isinstance(event, pipeline_events_base.DisconnectedEvent): - logger.debug( - "{}: DisconnectedEvent received. Calling on_disconnected_handler".format(self.name) - ) - if self.on_disconnected_handler: - pipeline_thread.invoke_on_callback_thread_nowait(self.on_disconnected_handler)() - - elif isinstance(event, pipeline_events_base.NewSasTokenRequiredEvent): - logger.debug( - "{}: NewSasTokenRequiredEvent received. Calling on_new_sastoken_required_handler".format( - self.name - ) - ) - if self.on_new_sastoken_required_handler: - pipeline_thread.invoke_on_callback_thread_nowait( - self.on_new_sastoken_required_handler - )() - - elif isinstance(event, pipeline_events_base.BackgroundExceptionEvent): - logger.debug( - "{}: BackgroundExceptionEvent received. Calling on_background_exception_handler".format( - self.name - ) - ) - if self.on_background_exception_handler: - pipeline_thread.invoke_on_callback_thread_nowait( - self.on_background_exception_handler - )(event.e) - - # Events that are domain-specific and unique to each pipeline are handled by the provided - # domain-specific .on_pipeline_event_handler - else: - if self.on_pipeline_event_handler: - pipeline_thread.invoke_on_callback_thread_nowait(self.on_pipeline_event_handler)( - event - ) - else: - # unexpected condition: we should be handling all pipeline events - logger.debug("incoming {} event with no handler. 
dropping.".format(event.name)) - - -# NOTE: This stage could be a candidate for being refactored into some kind of other -# pipeline-related structure. What's odd about it as a stage is that it doesn't really respond -# to operations or events so much as it spawns them on a timer. -# Perhaps some kind of... Pipeline Daemon? -class SasTokenStage(PipelineStage): - # Amount of time, in seconds, prior to token expiration to trigger alarm - DEFAULT_TOKEN_UPDATE_MARGIN = 120 - - def __init__(self): - super().__init__() - # Indicates when token needs to be updated - self._token_update_alarm = None - # Indicates when to retry a failed reauthorization attempt - # (only used with renewable SAS auth) - self._reauth_retry_timer = None - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - if ( - isinstance(op, pipeline_ops_base.InitializePipelineOperation) - and self.nucleus.pipeline_configuration.sastoken is not None - ): - # Start an alarm (renewal or replacement depending on token type) - self._start_token_update_alarm() - self.send_op_down(op) - elif ( - isinstance(op, pipeline_ops_base.ReauthorizeConnectionOperation) - and self.nucleus.pipeline_configuration.sastoken is not None - ): - # NOTE 1: This case (currently) implies that we are using Non-Renewable SAS, - # although it's not enforced here (it's a product of how the pipeline and client are - # configured overall) - - # NOTE 2: There's a theoretically possible case where the new token has the same expiry - # time as the old token, and thus a new update alarm wouldn't really be required, but - # I don't want to include the complexity of checking. Just start a new alarm anyway. - - # NOTE 3: Yeah, this is the same logic as the above case for the InitializePipeline op, - # but if it weren't separate, how would you get all these nice informative comments? 
- # (Also, it leaves room for the logic to change in the future) - self._start_token_update_alarm() - self.send_op_down(op) - elif isinstance(op, pipeline_ops_base.ShutdownPipelineOperation): - self._cancel_token_update_alarm() - self._cancel_reauth_retry_timer() - self.send_op_down(op) - else: - self.send_op_down(op) - - @pipeline_thread.runs_on_pipeline_thread - def _cancel_token_update_alarm(self): - """Cancel and delete any pending update alarm""" - old_alarm = self._token_update_alarm - self._token_update_alarm = None - if old_alarm: - logger.debug("Cancelling SAS Token update alarm") - old_alarm.cancel() - old_alarm = None - - @pipeline_thread.runs_on_pipeline_thread - def _cancel_reauth_retry_timer(self): - """Cancel and delete any pending reauth retry timer""" - old_reauth_retry_timer = self._reauth_retry_timer - self._reauth_retry_timer = None - if old_reauth_retry_timer: - logger.debug("Cancelling reauthorization retry timer") - old_reauth_retry_timer.cancel() - old_reauth_retry_timer = None - - @pipeline_thread.runs_on_pipeline_thread - def _start_token_update_alarm(self): - """Begin an update alarm. - If using a RenewableSasToken, when the alarm expires the token will be automatically - renewed, and a new alarm will be set. - - If using a NonRenewableSasToken, when the alarm expires, it will trigger a - NewSasTokenRequiredEvent to signal that a new SasToken must be manually provided. - """ - self._cancel_token_update_alarm() - - update_time = ( - self.nucleus.pipeline_configuration.sastoken.expiry_time - - self.DEFAULT_TOKEN_UPDATE_MARGIN - ) - - # On Windows platforms, the threading event TIMEOUT_MAX (approximately 49.7 days) could - # conceivably be less than the SAS lifespan, which means we may need to update the token - # before the lifespan ends. - # If we really wanted to adjust this in the future to use the entire SAS lifespan, we could - # implement Alarms that trigger other Alarms, but for now, just forcing a token update - # is good enough. 
- # Note that this doesn't apply to (most) Unix platforms, where TIMEOUT_MAX is 292.5 years. - if (update_time - time.time()) > threading.TIMEOUT_MAX: - update_time = time.time() + threading.TIMEOUT_MAX - logger.warning( - "SAS Token expiration ({expiry} seconds) exceeds max scheduled renewal time ({max} seconds). Will be renewing after {max} seconds instead".format( - expiry=self.nucleus.pipeline_configuration.sastoken.expiry_time, - max=threading.TIMEOUT_MAX, - ) - ) - - self_weakref = weakref.ref(self) - - # For renewable SasTokens, create an alarm that will automatically renew the token, - # and then start another alarm. - if isinstance(self.nucleus.pipeline_configuration.sastoken, st.RenewableSasToken): - logger.debug( - "{}: Scheduling automatic SAS Token renewal at epoch time: {}".format( - self.name, update_time - ) - ) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def renew_token(): - this = self_weakref() - # Cancel any token reauth retry timer in progress (from a previous renewal) - this._cancel_reauth_retry_timer() - logger.info("{}: Renewing SAS Token...".format(self.name)) - # Renew the token - sastoken = this.nucleus.pipeline_configuration.sastoken - try: - sastoken.refresh() - except st.SasTokenError as e: - logger.error("{}: SAS Token renewal failed".format(self.name)) - this.report_background_exception(e) - # TODO: then what? How do we respond to this? Retry? - # What if it never works and the token expires? - else: - # If the pipeline is already connected, send order to reauthorize the connection - # now that token has been renewed. If the pipeline is not currently connected, - # there is no need to do this, as the next connection will be using the new - # credentials. 
- if this.nucleus.connected: - this._reauthorize() - - # Once again, start a renewal alarm - this._start_token_update_alarm() - - self._token_update_alarm = alarm.Alarm(update_time, renew_token) - - # For nonrenewable SasTokens, create an alarm that will issue a NewSasTokenRequiredEvent - else: - logger.debug( - "Scheduling manual SAS Token renewal at epoch time: {}".format(update_time) - ) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def request_new_token(): - this = self_weakref() - logger.info("Requesting new SAS Token....") - # Send request - this.send_event_up(pipeline_events_base.NewSasTokenRequiredEvent()) - - self._token_update_alarm = alarm.Alarm(update_time, request_new_token) - - self._token_update_alarm.daemon = True - self._token_update_alarm.start() - - @pipeline_thread.runs_on_pipeline_thread - def _reauthorize(self): - self_weakref = weakref.ref(self) - - @pipeline_thread.runs_on_pipeline_thread - def on_reauthorize_complete(op, error): - this = self_weakref() - if error: - logger.info( - "{}: Connection reauthorization failed. Error={}".format(this.name, error) - ) - self.report_background_exception(error) - # If connection has not been somehow re-established, we need to keep trying - # because for the reauthorization to originally have been issued, we were in - # a connected state. - # NOTE: we only do this if connection retry is enabled on the pipeline. If it is, - # we have a contract to maintain a connection. If it has been disabled, we have - # a contract to not do so. - # NOTE: We can't rely on the ConnectionStateStage to do this because 1) the pipeline - # stages should stand on their own, and 2) if the reauth failed, the ConnectionStateStage - # wouldn't know to reconnect, because the expected state of a failed reauth is - # to be disconnected. 
- if ( - not this.nucleus.connected - and this.nucleus.pipeline_configuration.connection_retry - ): - logger.info("{}: Retrying connection reauthorization".format(this.name)) - # No need to cancel the timer, because if this is running, it has already ended - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def retry_reauthorize(): - # We need to check this when the timer expires as well as before creating - # the timer in case connection has been re-established while timer was - # running - if not this.nucleus.connected: - this._reauthorize() - - this._reauth_retry_timer = threading.Timer( - this.nucleus.pipeline_configuration.connection_retry_interval, - retry_reauthorize, - ) - this._reauth_retry_timer.daemon = True - this._reauth_retry_timer.start() - - else: - logger.info("{}: Connection reauthorization successful".format(this.name)) - - logger.info("{}: Starting reauthorization process for new SAS token".format(self.name)) - self.send_op_down( - pipeline_ops_base.ReauthorizeConnectionOperation(callback=on_reauthorize_complete) - ) - - -class AutoConnectStage(PipelineStage): - """ - This stage is responsible for ensuring that the protocol is connected when - it needs to be connected. - """ - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - # Any operation that requires a connection can trigger a connection if - # we're not connected and the auto-connect feature is enabled. - if ( - op.needs_connection - and not self.nucleus.connected - and self.nucleus.pipeline_configuration.auto_connect - ): - logger.debug( - "{}({}): Op needs connection. 
Queueing this op and starting a ConnectionOperation".format( - self.name, op.name - ) - ) - self._do_connect(op) - - else: - self.send_op_down(op) - - @pipeline_thread.runs_on_pipeline_thread - def _do_connect(self, op): - """ - Start connecting the transport in response to some operation - """ - # Alias to avoid overload within the callback below - # CT-TODO: remove the need for this with better callback semantics - op_needs_connect = op - - # function that gets called after we're connected. - @pipeline_thread.runs_on_pipeline_thread - def on_connect_op_complete(op, error): - if error: - logger.debug( - "{}({}): Connection failed. Completing with failure because of connection failure: {}".format( - self.name, op_needs_connect.name, error - ) - ) - op_needs_connect.complete(error=error) - else: - logger.debug( - "{}({}): connection is complete. Running op that triggered connection.".format( - self.name, op_needs_connect.name - ) - ) - self.run_op(op_needs_connect) - - # call down to the next stage to connect. - logger.debug("{}({}): calling down with Connect operation".format(self.name, op.name)) - self.send_op_down(pipeline_ops_base.ConnectOperation(callback=on_connect_op_complete)) - - -class CoordinateRequestAndResponseStage(PipelineStage): - """ - Pipeline stage which is responsible for coordinating RequestAndResponseOperation operations. For each - RequestAndResponseOperation operation, this stage passes down a RequestOperation operation and waits for - an ResponseEvent event. All other events are passed down unmodified. - """ - - def __init__(self): - super().__init__() - self.pending_responses = {} - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - if isinstance(op, pipeline_ops_base.RequestAndResponseOperation): - # Convert RequestAndResponseOperation operation into a RequestOperation operation - # and send it down. A lower level will convert the RequestOperation into an - # actual protocol client operation. 
The RequestAndResponseOperation operation will be - # completed when the corresponding IotResponse event is received in this stage. - - request_id = str(uuid.uuid4()) - - logger.debug( - "{}({}): adding request {} to pending list".format(self.name, op.name, request_id) - ) - self.pending_responses[request_id] = op - - self._send_request_down(request_id, op) - - else: - self.send_op_down(op) - - @pipeline_thread.runs_on_pipeline_thread - def _send_request_down(self, request_id, op): - # Alias to avoid overload within the callback below - # CT-TODO: remove the need for this with better callback semantics - op_waiting_for_response = op - - @pipeline_thread.runs_on_pipeline_thread - def on_send_request_done(op, error): - logger.debug( - "{}({}): Finished sending {} request to {} resource {}".format( - self.name, - op_waiting_for_response.name, - op_waiting_for_response.request_type, - op_waiting_for_response.method, - op_waiting_for_response.resource_location, - ) - ) - if error: - logger.debug( - "{}({}): removing request {} from pending list".format( - self.name, op_waiting_for_response.name, request_id - ) - ) - # if there's no pending response for the given request_id, there's nothing to delete - if request_id in self.pending_responses: - del self.pending_responses[request_id] - op_waiting_for_response.complete(error=error) - else: - # NOTE: This shouldn't ever happen under normal conditions, but the following logic - # ensures that, if it does, it's handled safely. - logger.debug( - "{}({}): request_id {} not found in pending list. Unexpected behavior. Dropping".format( - self.name, op_waiting_for_response.name, request_id - ) - ) - pass - else: - # request sent. 
Nothing to do except wait for the response - pass - - logger.debug( - "{}({}): Sending {} request to {} resource {}".format( - self.name, op.name, op.request_type, op.method, op.resource_location - ) - ) - - new_op = pipeline_ops_base.RequestOperation( - method=op.method, - resource_location=op.resource_location, - request_body=op.request_body, - request_id=request_id, - request_type=op.request_type, - callback=on_send_request_done, - query_params=op.query_params, - ) - self.send_op_down(new_op) - - @pipeline_thread.runs_on_pipeline_thread - def _handle_pipeline_event(self, event): - if isinstance(event, pipeline_events_base.ResponseEvent): - # match ResponseEvent events to the saved dictionary of RequestAndResponseOperation - # operations which have not received responses yet. If the operation is found, - # complete it. - - logger.debug( - "{}({}): Handling event with request_id {}".format( - self.name, event.name, event.request_id - ) - ) - if event.request_id in self.pending_responses: - op = self.pending_responses[event.request_id] - del self.pending_responses[event.request_id] - op.status_code = event.status_code - op.response_body = event.response_body - op.retry_after = event.retry_after - logger.debug( - "{}({}): Completing {} request to {} resource {} with status {}".format( - self.name, - op.name, - op.request_type, - op.method, - op.resource_location, - op.status_code, - ) - ) - op.complete() - else: - logger.info( - "{}({}): request_id {} not found in pending list. Nothing to do. Dropping".format( - self.name, event.name, event.request_id - ) - ) - - elif isinstance(event, pipeline_events_base.ConnectedEvent): - """ - If we're reconnecting, send all pending requests down again. This is necessary - because any response that might have been sent by the service was possibly lost - when the connection dropped. The fact that the operation is still pending means - that we haven't received the response yet. 
Sending the request more than once - will result in a reasonable response for all known operations, aside from extra - processing on the server in the case of a re-sent provisioning request, or the - appearance of a jump in $version attributes in the case of a lost twin PATCH - operation. Since we're reusing the same $rid, the server, of course, _could_ - recognize that this is a duplicate request, but the behavior in this case is - undefined. - """ - - for request_id in self.pending_responses: - logger.info( - "{stage}: ConnectedEvent: re-publishing request {id} for {method} {type} ".format( - stage=self.name, - id=request_id, - method=self.pending_responses[request_id].method, - type=self.pending_responses[request_id].request_type, - ) - ) - self._send_request_down(request_id, self.pending_responses[request_id]) - - self.send_event_up(event) - - else: - self.send_event_up(event) - - -class OpTimeoutStage(PipelineStage): - """ - The purpose of the timeout stage is to add timeout errors to select operations - - The timeout_intervals attribute contains a list of operations to track along with - their timeout values. Right now this list is hard-coded but the operations and - intervals will eventually become a parameter. - - For each operation that needs a timeout check, this stage will add a timer to - the operation. If the timer elapses, this stage will fail the operation with - a OperationTimeout. The intention is that a higher stage will know what to - do with that error and act accordingly (either return the error to the user or - retry). - - This stage currently assumes that all timed out operation are just "lost". - It does not attempt to cancel the operation, as Paho doesn't have a way to - cancel an operation, and with QOS=1, sending a pub or sub twice is not - catastrophic. 
- - Also, as a long-term plan, the operations that need to be watched for timeout - will become an initialization parameter for this stage so that different - instances of this stage can watch for timeouts on different operations. - This will be done because we want a lower-level timeout stage which can watch - for timeouts at the MQTT level, and we want a higher-level timeout stage which - can watch for timeouts at the iothub level. In this way, an MQTT operation that - times out can be retried as an MQTT operation and a higher-level IoTHub operation - which times out can be retried as an IoTHub operation (which might necessitate - redoing multiple MQTT operations). - """ - - def __init__(self): - super().__init__() - # use a fixed list and fixed intervals for now. Later, this info will come in - # as an init param or a retry policy - self.timeout_intervals = { - pipeline_ops_mqtt.MQTTSubscribeOperation: 10, - pipeline_ops_mqtt.MQTTUnsubscribeOperation: 10, - # Only Sub and Unsub are here because MQTT auto retries pub - } - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - if type(op) in self.timeout_intervals: - # Create a timer to watch for operation timeout on this op and attach it - # to the op. 
- self_weakref = weakref.ref(self) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def on_timeout(): - this = self_weakref() - logger.info("{}({}): returning timeout error".format(this.name, op.name)) - op.complete( - error=pipeline_exceptions.OperationTimeout( - "operation timed out before protocol client could respond" - ) - ) - - logger.debug("{}({}): Creating timer".format(self.name, op.name)) - op.timeout_timer = threading.Timer(self.timeout_intervals[type(op)], on_timeout) - op.timeout_timer.start() - - # Send the op down, but intercept the return of the op so we can - # remove the timer when the op is done - op.add_callback(self._clear_timer) - logger.debug("{}({}): Sending down".format(self.name, op.name)) - self.send_op_down(op) - else: - self.send_op_down(op) - - @pipeline_thread.runs_on_pipeline_thread - def _clear_timer(self, op, error): - # When an op comes back, delete the timer and pass it right up. - if op.timeout_timer: - logger.debug("{}({}): Cancelling timer".format(self.name, op.name)) - op.timeout_timer.cancel() - op.timeout_timer = None - - -class RetryStage(PipelineStage): - """ - The purpose of the retry stage is to watch specific operations for specific - errors and retry the operations as appropriate. - - Unlike the OpTimeoutStage, this stage will never need to worry about cancelling - failed operations. When an operation is retried at this stage, it is already - considered "failed", so no cancellation needs to be done. - """ - - def __init__(self): - super().__init__() - # Retry intervals are hardcoded for now. Later, they come in as an - # init param, probably via retry policy. 
- self.retry_intervals = { - pipeline_ops_mqtt.MQTTSubscribeOperation: 20, - pipeline_ops_mqtt.MQTTUnsubscribeOperation: 20, - # Only Sub and Unsub are here because MQTT auto retries pub - } - self.ops_waiting_to_retry = [] - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - """ - Send all ops down and intercept their return to "watch for retry" - """ - if self._should_watch_for_retry(op): - op.add_callback(self._do_retry_if_necessary) - self.send_op_down(op) - else: - self.send_op_down(op) - - @pipeline_thread.runs_on_pipeline_thread - def _should_watch_for_retry(self, op): - """ - Return True if this op needs to be watched for retry. This can be - called before the op runs. - """ - return type(op) in self.retry_intervals - - @pipeline_thread.runs_on_pipeline_thread - def _should_retry(self, op, error): - """ - Return True if this op needs to be retried. This must be called after - the op completes. - """ - if error: - if self._should_watch_for_retry(op): - if isinstance(error, pipeline_exceptions.OperationTimeout): - return True - return False - - @pipeline_thread.runs_on_pipeline_thread - def _do_retry_if_necessary(self, op, error): - """ - Handler which gets called when operations are complete. This function - is where we check to see if a retry is necessary and set a "retry timer" - which can be used to send the op down again. - """ - if self._should_retry(op, error): - self_weakref = weakref.ref(self) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def do_retry(): - this = self_weakref() - logger.debug("{}({}): retrying".format(this.name, op.name)) - op.retry_timer.cancel() - op.retry_timer = None - this.ops_waiting_to_retry.remove(op) - # Don't just send it down directly. Instead, go through run_op so we get - # retry functionality this time too - this.run_op(op) - - interval = self.retry_intervals[type(op)] - logger.info( - "{}({}): Op needs retry with interval {} because of {}. 
Setting timer.".format( - self.name, op.name, interval, error - ) - ) - - # if we don't keep track of this op, it might get collected. - op.halt_completion() - self.ops_waiting_to_retry.append(op) - op.retry_timer = threading.Timer(self.retry_intervals[type(op)], do_retry) - op.retry_timer.start() - - else: - if op.retry_timer: - op.retry_timer.cancel() - op.retry_timer = None - - -class ConnectionStateStage(PipelineStage): - - intermediate_states = [ - ConnectionState.CONNECTING, - ConnectionState.DISCONNECTING, - ConnectionState.REAUTHORIZING, - ] - transient_connect_errors = [ - pipeline_exceptions.OperationCancelled, - pipeline_exceptions.OperationTimeout, - pipeline_exceptions.OperationError, - transport_exceptions.ConnectionFailedError, - transport_exceptions.ConnectionDroppedError, - transport_exceptions.TlsExchangeAuthError, - ] - - def __init__(self): - super().__init__() - self.reconnect_timer = None - self.waiting_ops = queue.Queue() - - # NOTE: In this stage states are both checked, and changed, but there is no lock to protect - # this state value, or the logic that surrounds it from multithreading. This is because due - # to the threading model of the pipeline, there is a dedicated pipeline thread that handles - # everything that runs here, and it can only be doing one thing at a time. Thus we don't - # need to have a threading lock on our state, or be concerned with how atomic things are. - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - - # If receiving an operation while the connection state is changing, wait for the - # connection state to reach a stable state before continuing. 
- if self.nucleus.connection_state in self.intermediate_states: - logger.debug( - "{}({}): State is {} - waiting for in-progress operation to finish".format( - self.name, op.name, self.nucleus.connection_state - ) - ) - self.waiting_ops.put_nowait(op) - - else: - if isinstance(op, pipeline_ops_base.ConnectOperation): - if self.nucleus.connection_state is ConnectionState.CONNECTED: - logger.debug( - "{}({}): State is already CONNECTED. Completing operation".format( - self.name, op.name - ) - ) - op.complete() - elif self.nucleus.connection_state is ConnectionState.DISCONNECTED: - logger.debug( - "{}({}): State changes DISCONNECTED -> CONNECTING. Sending op down".format( - self.name, op.name - ) - ) - self.nucleus.connection_state = ConnectionState.CONNECTING - self._add_connection_op_callback(op) - self.send_op_down(op) - else: - # This should be impossible to reach. If the state were intermediate, it - # would have been added to the waiting ops queue above. - logger.warning( - "{}({}): Invalid State - {}".format( - self.name, op.name, self.nucleus.connection_state - ) - ) - self.send_op_down(op) - - elif isinstance(op, pipeline_ops_base.DisconnectOperation): - # First, always clear any reconnect timer. Because a manual disconnection is - # occurring, we won't want to be reconnecting any more. - self._clear_reconnect_timer() - - if self.nucleus.connection_state is ConnectionState.CONNECTED: - logger.debug( - "{}({}): State changes CONNECTED -> DISCONNECTING. Sending op down.".format( - self.name, op.name - ) - ) - self.nucleus.connection_state = ConnectionState.DISCONNECTING - self._add_connection_op_callback(op) - self.send_op_down(op) - elif self.nucleus.connection_state is ConnectionState.DISCONNECTED: - logger.debug( - "{}({}): State is already DISCONNECTED. Completing operation".format( - self.name, op.name - ) - ) - op.complete() - else: - # This should be impossible to reach. 
If the state were intermediate, it - # would have been added to the waiting ops queue above. - logger.warning( - "{}({}): Invalid State - {}".format( - self.name, op.name, self.nucleus.connection_state - ) - ) - self.send_op_down(op) - - elif isinstance(op, pipeline_ops_base.ReauthorizeConnectionOperation): - if self.nucleus.connection_state is ConnectionState.CONNECTED: - logger.debug( - "{}({}): State changes CONNECTED -> REAUTHORIZING. Sending op down.".format( - self.name, op.name - ) - ) - self.nucleus.connection_state = ConnectionState.REAUTHORIZING - self._add_connection_op_callback(op) - self.send_op_down(op) - elif self.nucleus.connection_state is ConnectionState.DISCONNECTED: - logger.debug( - "{}({}): State changes DISCONNECTED -> REAUTHORIZING. Sending op down".format( - self.name, op.name - ) - ) - self.nucleus.connection_state = ConnectionState.REAUTHORIZING - self._add_connection_op_callback(op) - self.send_op_down(op) - else: - # This should be impossible to reach. If the state were intermediate, it - # would have been added to the waiting ops queue above. - logger.warning( - "{}({}): Invalid State - {}".format( - self.name, op.name, self.nucleus.connection_state - ) - ) - self.send_op_down(op) - - elif isinstance(op, pipeline_ops_base.ShutdownPipelineOperation): - self._clear_reconnect_timer() - # Cancel all pending ops so they don't hang - while not self.waiting_ops.empty(): - waiting_op = self.waiting_ops.get_nowait() - cancel_error = pipeline_exceptions.OperationCancelled( - "Operation waiting in ConnectionStateStage cancelled by shutdown" - ) - waiting_op.complete(error=cancel_error) - self.send_op_down(op) - - else: - self.send_op_down(op) - - @pipeline_thread.runs_on_pipeline_thread - def _handle_pipeline_event(self, event): - - if isinstance(event, pipeline_events_base.ConnectedEvent): - # First, clear the reconnect timer no matter what. 
- # We are now connected, so any ongoing reconnect is unnecessary - self._clear_reconnect_timer() - - # EXPECTED CONNECTION (ConnectOperation was previously issued) - if self.nucleus.connection_state is ConnectionState.CONNECTING: - logger.debug( - "{}({}): State changes CONNECTING -> CONNECTED. Connection established".format( - self.name, event.name - ) - ) - self.nucleus.connection_state = ConnectionState.CONNECTED - - # EXPECTED CONNECTION (ReauthorizeConnectionOperation was previously issued) - elif self.nucleus.connection_state is ConnectionState.REAUTHORIZING: - logger.debug( - "{}({}): State changes REAUTHORIZING -> CONNECTED. Connection re-established after re-authentication".format( - self.name, event.name - ) - ) - self.nucleus.connection_state = ConnectionState.CONNECTED - - # BAD STATE (this block should not be reached) - else: - logger.warning( - "{}: ConnectedEvent received while in unexpected state - {}".format( - self.name, self.nucleus.connection_state - ) - ) - logger.debug( - "{}({}): State changes {} -> CONNECTED. Unexpected connection".format( - self.name, event.name, self.nucleus.connection_state - ) - ) - self.nucleus.connection_state = ConnectionState.CONNECTED - - elif isinstance(event, pipeline_events_base.DisconnectedEvent): - # UNEXPECTED DISCONNECTION (i.e. Connection has been lost) - if self.nucleus.connection_state is ConnectionState.CONNECTED: - - # Set the state change before starting the timer in order to make sure - # there's no issues when the timer expires. The pipeline threading model should - # already be preventing any weirdness with timing, but can't hurt to do this - # as well. - self.nucleus.connection_state = ConnectionState.DISCONNECTED - - if self.nucleus.pipeline_configuration.connection_retry: - # When we get disconnected, we try to reconnect as soon as we can. 
We set a - # timer here that will start the process in another thread because we don't - # want to hold up the event flow - logger.debug( - "{}({}): State changes CONNECTED -> DISCONNECTED. Attempting to reconnect".format( - self.name, event.name - ) - ) - self._start_reconnect_timer(0.01) - else: - logger.debug( - "{}({}): State changes CONNECTED -> DISCONNECTED. Not attempting to reconnect (Connection retry disabled)".format( - self.name, event.name - ) - ) - - # EXPECTED DISCONNECTION (DisconnectOperation was previously issued) - elif self.nucleus.connection_state is ConnectionState.DISCONNECTING: - # No reconnect timer will be created. - logger.debug( - "{}({}): State changes DISCONNECTING -> DISCONNECTED. Not attempting to reconnect (User-initiated disconnect)".format( - self.name, event.name - ) - ) - self.nucleus.connection_state = ConnectionState.DISCONNECTED - - # EXPECTED DISCONNECTION (Reauthorization process) - elif self.nucleus.connection_state is ConnectionState.REAUTHORIZING: - # ConnectionState will remain REAUTHORIZING until completion of the process - # upon re-establishing the connection - - # NOTE: There is a ~small~ chance of a false positive here if an unexpected - # disconnection occurs while a ReauthorizationOperation is in flight. - # However, it will sort itself out - the ensuing connect that occurs as part - # of the reauthorization will restore connection (no harm done) or it will - # fail, at which point the failure was a result of a manual operation and - # reconnection is not supposed to occur. So either way, we end up where we want - # to be despite the false positive - just be aware that this can happen. 
- logger.debug( - "{}({}): Not attempting to reconnect (Reauthorization in progress)".format( - self.name, event.name - ) - ) - - # BAD STATE (this block should not be reached) - else: - logger.warning( - "{}: DisconnectEvent received while in unexpected state - {}".format( - self.name, self.nucleus.connection_state - ) - ) - logger.debug( - "{}({}): State changes {} -> DISCONNECTED. Unexpected disconnect in unexpected state".format( - self.name, event.name, self.nucleus.connection_state - ) - ) - self.nucleus.connection_state = ConnectionState.DISCONNECTED - - # In all cases the event is sent up - self.send_event_up(event) - - @pipeline_thread.runs_on_pipeline_thread - def _add_connection_op_callback(self, op): - """Adds callback to a connection op passing through to do necessary stage upkeep""" - self_weakref = weakref.ref(self) - - @pipeline_thread.runs_on_pipeline_thread - def on_complete(op, error): - this = self_weakref() - # If error, set us back to a DISCONNECTED state. It doesn't matter what kind of - # connection op this was, any failure should result in a disconnected state. - - # NOTE: Due to the stage waiting any ops if an ongoing connection op is in-progress - # as well as the way that the reconnection process checks if there is an in-progress - # connection op (and punts the reconnect if so), there is no risk here of setting - # directly to DISCONNECTED - the intermediate state being overwritten is always going - # to be due to this op that is now completing, we can be assured of that. 
- if error: - logger.debug( - "{}({}): failed, state change {} -> DISCONNECTED".format( - this.name, op.name, this.nucleus.connection_state - ) - ) - this.nucleus.connection_state = ConnectionState.DISCONNECTED - - # Allow the next waiting op to proceed (if any) - this._run_all_waiting_ops() - - op.add_callback(on_complete) - - @pipeline_thread.runs_on_pipeline_thread - def _run_all_waiting_ops(self): - - if not self.waiting_ops.empty(): - queuecopy = self.waiting_ops - self.waiting_ops = queue.Queue() - - while not queuecopy.empty(): - next_op = queuecopy.get_nowait() - if not next_op.completed: - logger.debug( - "{}: Resolving next waiting op: {}".format(self.name, next_op.name) - ) - self.run_op(next_op) - - @pipeline_thread.runs_on_pipeline_thread - def _reconnect(self): - self_weakref = weakref.ref(self) - - @pipeline_thread.runs_on_pipeline_thread - def on_reconnect_complete(op, error): - this = self_weakref() - if this: - logger.debug( - "{}({}): on_connect_complete error={} state={} ".format( - this.name, - op.name, - error, - this.nucleus.connection_state, - ) - ) - - if error: - # Set state back to DISCONNECTED so as not to block anything else - logger.debug( - "{}: State change {} -> DISCONNECTED".format( - this.name, this.nucleus.connection_state - ) - ) - this.nucleus.connection_state = ConnectionState.DISCONNECTED - - # report background exception to indicate this failure occurred - this.report_background_exception(error) - - # Determine if should try reconnect again - if this._should_reconnect(error): - # transient errors can cause a reconnect attempt - logger.debug( - "{}: Reconnect failed. Starting reconnection timer".format(this.name) - ) - this._start_reconnect_timer( - this.nucleus.pipeline_configuration.connection_retry_interval - ) - else: - # all others are permanent errors - logger.debug( - "{}: Cannot reconnect. 
Ending reconnection process".format(this.name) - ) - - # Now see if there's anything that may have blocked waiting for us to finish - this._run_all_waiting_ops() - - # NOTE: I had considered leveraging the run_op infrastructure instead of sending this - # directly down. Ultimately however, I think it's best to keep reconnects completely - # distinct from other operations that come through the pipeline - for instance, we don't - # really want them to end up queued up behind other operations via the .waiting_ops queue. - # Reconnects have a top priority. - op = pipeline_ops_base.ConnectOperation(callback=on_reconnect_complete) - self.send_op_down(op) - - @pipeline_thread.runs_on_pipeline_thread - def _should_reconnect(self, error): - """Returns True if a reconnect should occur in response to an error, False otherwise""" - if self.nucleus.pipeline_configuration.connection_retry: - if type(error) in self.transient_connect_errors: - return True - return False - - @pipeline_thread.runs_on_pipeline_thread - def _start_reconnect_timer(self, delay): - """ - Set a timer to reconnect after some period of time - """ - self._clear_reconnect_timer() - - self_weakref = weakref.ref(self) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def on_reconnect_timer_expired(): - this = self_weakref() - logger.debug( - "{}: Reconnect timer expired. State is {}.".format( - self.name, self.nucleus.connection_state - ) - ) - # Clear the reconnect timer here first and foremost so it doesn't accidentally - # get left around somehow. Don't use the _clear_reconnect_timer method, as the timer - # has expired, and thus cannot be cancelled. - this.reconnect_timer = None - - if this.nucleus.connection_state is ConnectionState.DISCONNECTED: - # We are still disconnected, so reconnect - - # NOTE: Because any reconnect timer would have been cancelled upon a manual - # disconnect, there is no way this block could be executing if we were happy - # with our DISCONNECTED state. 
- logger.debug("{}: Starting reconnection".format(this.name)) - logger.debug( - "{}: State changes {} -> CONNECTING. Sending new connect op down in reconnect attempt".format( - self.name, self.nucleus.connection_state - ) - ) - self.nucleus.connection_state = ConnectionState.CONNECTING - this._reconnect() - elif this.nucleus.connection_state in self.intermediate_states: - # If another connection op is in progress, just wait and try again later to avoid - # any extra confusion (i.e. punt the reconnection) - logger.debug( - "{}: Other connection operation in-progress, setting a new reconnection timer".format( - this.name - ) - ) - this._start_reconnect_timer( - this.nucleus.pipeline_configuration.connection_retry_interval - ) - else: - logger.debug( - "{}: Unexpected state reached ({}) after reconnection timer expired".format( - this.name, this.nucleus.connection_state - ) - ) - - self.reconnect_timer = threading.Timer(delay, on_reconnect_timer_expired) - self.reconnect_timer.start() - - @pipeline_thread.runs_on_pipeline_thread - def _clear_reconnect_timer(self): - """ - Clear any previous reconnect timer - """ - if self.reconnect_timer: - logger.debug("{}: clearing reconnect timer".format(self.name)) - self.reconnect_timer.cancel() - self.reconnect_timer = None diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_http.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_http.py deleted file mode 100644 index 79cfbfbbd..000000000 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_http.py +++ /dev/null @@ -1,113 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -import copy -from . 
import ( - pipeline_ops_base, - PipelineStage, - pipeline_ops_http, - pipeline_thread, -) -from azure.iot.device.common.http_transport import HTTPTransport - -logger = logging.getLogger(__name__) - - -class HTTPTransportStage(PipelineStage): - """ - PipelineStage object which is responsible for interfacing with the HTTP protocol wrapper object. - This stage handles all HTTP operations that are not specific to IoT Hub. - """ - - def __init__(self): - super().__init__() - # The sas_token will be set when Connection Args are received - self.sas_token = None - - # The transport will be instantiated when Connection Args are received - self.transport = None - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - if isinstance(op, pipeline_ops_base.InitializePipelineOperation): - - # If there is a gateway hostname, use that as the hostname for connection, - # rather than the hostname itself - if self.nucleus.pipeline_configuration.gateway_hostname: - logger.debug( - "Gateway Hostname Present. Setting Hostname to: {}".format( - self.nucleus.pipeline_configuration.gateway_hostname - ) - ) - hostname = self.nucleus.pipeline_configuration.gateway_hostname - else: - logger.debug( - "Gateway Hostname not present. 
Setting Hostname to: {}".format( - self.nucleus.pipeline_configuration.hostname - ) - ) - hostname = self.nucleus.pipeline_configuration.hostname - - # Create HTTP Transport - logger.debug("{}({}): got connection args".format(self.name, op.name)) - self.transport = HTTPTransport( - hostname=hostname, - server_verification_cert=self.nucleus.pipeline_configuration.server_verification_cert, - x509_cert=self.nucleus.pipeline_configuration.x509, - cipher=self.nucleus.pipeline_configuration.cipher, - proxy_options=self.nucleus.pipeline_configuration.proxy_options, - ) - - self.nucleus.transport = self.transport - op.complete() - - elif isinstance(op, pipeline_ops_http.HTTPRequestAndResponseOperation): - # This will call down to the HTTP Transport with a request and also created a request callback. Because the HTTP Transport will run on the http transport thread, this call should be non-blocking to the pipeline thread. - logger.debug( - "{}({}): Generating HTTP request and setting callback before completing.".format( - self.name, op.name - ) - ) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def on_request_completed(error=None, response=None): - if error: - logger.debug( - "{}({}): Error passed to on_request_completed. Error={}".format( - self.name, op.name, error - ) - ) - op.complete(error=error) - else: - logger.debug( - "{}({}): Request completed. Completing op.".format(self.name, op.name) - ) - logger.debug("HTTP Response Status: {}".format(response["status_code"])) - logger.debug("HTTP Response: {}".format(response["resp"])) - op.response_body = response["resp"] - op.status_code = response["status_code"] - op.reason = response["reason"] - op.complete() - - # A deepcopy is necessary here since otherwise the manipulation happening to - # http_headers will affect the op.headers, which would be an unintended side effect - # and not a good practice. 
- http_headers = copy.deepcopy(op.headers) - if self.nucleus.pipeline_configuration.sastoken: - http_headers["Authorization"] = str(self.nucleus.pipeline_configuration.sastoken) - - self.transport.request( - method=op.method, - path=op.path, - headers=http_headers, - query_params=op.query_params, - body=op.body, - callback=on_request_completed, - ) - - else: - self.send_op_down(op) diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_mqtt.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_mqtt.py deleted file mode 100644 index 236ca244b..000000000 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_mqtt.py +++ /dev/null @@ -1,464 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -import traceback -import threading -import weakref -from . import ( - pipeline_ops_base, - PipelineStage, - pipeline_ops_mqtt, - pipeline_events_mqtt, - pipeline_thread, - pipeline_exceptions, - pipeline_events_base, -) -from azure.iot.device.common.mqtt_transport import MQTTTransport -from azure.iot.device.common import handle_exceptions, transport_exceptions - -logger = logging.getLogger(__name__) - -# Maximum amount of time we wait for ConnectOperation to complete -# TODO: This whole logic of timeout should probably be handled in the TimeoutStage -WATCHDOG_INTERVAL = 60 - - -class MQTTTransportStage(PipelineStage): - """ - PipelineStage object which is responsible for interfacing with the MQTT protocol wrapper object. - This stage handles all MQTT operations and any other operations (such as ConnectOperation) which - is not in the MQTT group of operations, but can only be run at the protocol level. 
- """ - - def __init__(self): - super().__init__() - - # The transport will be instantiated upon receiving the InitializePipelineOperation - self.transport = None - # The current in-progress op that affects connection state (Connect, Disconnect, Reauthorize) - self._pending_connection_op = None - - @pipeline_thread.runs_on_pipeline_thread - def _cancel_pending_connection_op(self, error=None): - """ - Cancel any running connect, disconnect or reauthorize connection op. Since our ability to "cancel" is fairly limited, - all this does (for now) is to fail the operation - """ - - op = self._pending_connection_op - if op: - # NOTE: This code path should NOT execute in normal flow. There should never already be a pending - # connection op when another is added, due to the ConnectionLock stage. - # If this block does execute, there is a bug in the codebase. - if not error: - error = pipeline_exceptions.OperationCancelled( - "Cancelling because new ConnectOperation or DisconnectOperation was issued" - ) - self._cancel_connection_watchdog(op) - self._pending_connection_op = None - op.complete(error=error) - - @pipeline_thread.runs_on_pipeline_thread - def _start_connection_watchdog(self, connection_op): - """ - Start a watchdog on the connection operation. This protects against cases where transport.connect() - succeeds but the CONNACK never arrives. This is like a timeout, but it is handled at this level - because specific cleanup needs to take place on timeout (see below), and this cleanup doesn't - belong anywhere else since it is very specific to this stage. - """ - logger.debug("{}({}): Starting watchdog".format(self.name, connection_op.name)) - - self_weakref = weakref.ref(self) - op_weakref = weakref.ref(connection_op) - - @pipeline_thread.invoke_on_pipeline_thread - def watchdog_function(): - this = self_weakref() - op = op_weakref() - if this and op and this._pending_connection_op is op: - logger.info( - "{}({}): Connection watchdog expired. 
Cancelling op".format(this.name, op.name) - ) - try: - this.transport.disconnect() - except Exception: - # If we don't catch this, the pending connection op might not ever be cancelled. - # Most likely, the transport isn't actually connected, but other failures are theoretically - # possible. Either way, if disconnect fails, we should assume that we're disconnected. - logger.info( - "transport.disconnect raised error while disconnecting in watchdog. Safe to ignore." - ) - logger.info(traceback.format_exc()) - - if this.nucleus.connected: - - logger.info( - "{}({}): Pipeline is still connected on watchdog expiration. Sending DisconnectedEvent".format( - this.name, op.name - ) - ) - this.send_event_up(pipeline_events_base.DisconnectedEvent()) - this._cancel_pending_connection_op( - error=pipeline_exceptions.OperationTimeout( - "Transport timeout on connection operation" - ) - ) - else: - logger.debug("Connection watchdog expired, but pending op is not the same op") - - connection_op.watchdog_timer = threading.Timer(WATCHDOG_INTERVAL, watchdog_function) - connection_op.watchdog_timer.daemon = True - connection_op.watchdog_timer.start() - - @pipeline_thread.runs_on_pipeline_thread - def _cancel_connection_watchdog(self, op): - try: - if op.watchdog_timer: - logger.debug("{}({}): cancelling watchdog".format(self.name, op.name)) - op.watchdog_timer.cancel() - op.watchdog_timer = None - except AttributeError: - pass - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - if isinstance(op, pipeline_ops_base.InitializePipelineOperation): - - # If there is a gateway hostname, use that as the hostname for connection, - # rather than the hostname itself - if self.nucleus.pipeline_configuration.gateway_hostname: - logger.debug( - "Gateway Hostname Present. 
Setting Hostname to: {}".format( - self.nucleus.pipeline_configuration.gateway_hostname - ) - ) - hostname = self.nucleus.pipeline_configuration.gateway_hostname - else: - logger.debug( - "Gateway Hostname not present. Setting Hostname to: {}".format( - self.nucleus.pipeline_configuration.hostname - ) - ) - hostname = self.nucleus.pipeline_configuration.hostname - - # Create the Transport object, set it's handlers - logger.debug("{}({}): got connection args".format(self.name, op.name)) - self.transport = MQTTTransport( - client_id=op.client_id, - hostname=hostname, - username=op.username, - server_verification_cert=self.nucleus.pipeline_configuration.server_verification_cert, - x509_cert=self.nucleus.pipeline_configuration.x509, - websockets=self.nucleus.pipeline_configuration.websockets, - cipher=self.nucleus.pipeline_configuration.cipher, - proxy_options=self.nucleus.pipeline_configuration.proxy_options, - keep_alive=self.nucleus.pipeline_configuration.keep_alive, - ) - self.transport.on_mqtt_connected_handler = self._on_mqtt_connected - self.transport.on_mqtt_connection_failure_handler = self._on_mqtt_connection_failure - self.transport.on_mqtt_disconnected_handler = self._on_mqtt_disconnected - self.transport.on_mqtt_message_received_handler = self._on_mqtt_message_received - - # There can only be one pending connection operation (Connect, Disconnect) - # at a time. The existing one must be completed or canceled before a new one is set. - - # Currently, this means that if, say, a connect operation is the pending op and is executed - # but another connection op is begins by the time the CONNACK is received, the original - # operation will be cancelled, but the CONNACK for it will still be received, and complete the - # NEW operation. This is not desirable, but it is how things currently work. - - # We are however, checking the type, so the CONNACK from a cancelled Connect, cannot successfully - # complete a Disconnect operation. 
- - # Note that a ReauthorizeConnectionOperation will never be pending because it will - # instead spawn separate Connect and Disconnect operations. - self._pending_connection_op = None - - op.complete() - - elif isinstance(op, pipeline_ops_base.ShutdownPipelineOperation): - try: - self.transport.shutdown() - except Exception as e: - logger.info("transport.shutdown raised error") - logger.info(traceback.format_exc()) - op.complete(error=e) - else: - op.complete() - - elif isinstance(op, pipeline_ops_base.ConnectOperation): - logger.debug("{}({}): connecting".format(self.name, op.name)) - - self._cancel_pending_connection_op() - self._pending_connection_op = op - self._start_connection_watchdog(op) - # Use SasToken as password if present. If not present (e.g. using X509), - # then no password is required because auth is handled via other means. - if self.nucleus.pipeline_configuration.sastoken: - password = str(self.nucleus.pipeline_configuration.sastoken) - else: - password = None - try: - self.transport.connect(password=password) - except Exception as e: - logger.info("transport.connect raised error") - logger.info(traceback.format_exc()) - self._cancel_connection_watchdog(op) - self._pending_connection_op = None - op.complete(error=e) - - elif isinstance(op, pipeline_ops_base.DisconnectOperation): - logger.debug("{}({}): disconnecting".format(self.name, op.name)) - - self._cancel_pending_connection_op() - self._pending_connection_op = op - # We don't need a watchdog on disconnect because there's no callback to wait for - # and we respond to a watchdog timeout by calling disconnect, which is what we're - # already doing. 
- - try: - # The connect after the disconnect will be triggered upon completion of the - # disconnect in the on_disconnected handler - self.transport.disconnect(clear_inflight=op.hard) - except Exception as e: - logger.info("transport.disconnect raised error while disconnecting") - logger.info(traceback.format_exc()) - self._pending_connection_op = None - op.complete(error=e) - - elif isinstance(op, pipeline_ops_base.ReauthorizeConnectionOperation): - logger.debug( - "{}({}): reauthorizing. Will issue disconnect and then a connect".format( - self.name, op.name - ) - ) - self_weakref = weakref.ref(self) - reauth_op = op # rename for clarity - - def on_disconnect_complete(op, error): - this = self_weakref() - if error: - # Failing a disconnect should still get us disconnected, so can proceed anyway - logger.debug( - "Disconnect failed during reauthorization, continuing with connect" - ) - connect_op = reauth_op.spawn_worker_op(pipeline_ops_base.ConnectOperation) - - # NOTE: this relies on the fact that before the disconnect is completed it is - # unset as the pending connection op. Otherwise there would be issues here. - this.run_op(connect_op) - - disconnect_op = pipeline_ops_base.DisconnectOperation(callback=on_disconnect_complete) - disconnect_op.hard = False - - self.run_op(disconnect_op) - - elif isinstance(op, pipeline_ops_mqtt.MQTTPublishOperation): - logger.debug("{}({}): publishing on {}".format(self.name, op.name, op.topic)) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def on_complete(cancelled=False): - if cancelled: - op.complete( - error=pipeline_exceptions.OperationCancelled( - "Operation cancelled before PUBACK received" - ) - ) - else: - logger.debug( - "{}({}): PUBACK received. 
completing op.".format(self.name, op.name) - ) - op.complete() - - try: - self.transport.publish(topic=op.topic, payload=op.payload, callback=on_complete) - except Exception as e: - op.complete(error=e) - - elif isinstance(op, pipeline_ops_mqtt.MQTTSubscribeOperation): - logger.debug("{}({}): subscribing to {}".format(self.name, op.name, op.topic)) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def on_complete(cancelled=False): - if cancelled: - op.complete( - error=pipeline_exceptions.OperationCancelled( - "Operation cancelled before SUBACK received" - ) - ) - else: - logger.debug( - "{}({}): SUBACK received. completing op.".format(self.name, op.name) - ) - op.complete() - - try: - self.transport.subscribe(topic=op.topic, callback=on_complete) - except Exception as e: - op.complete(error=e) - - elif isinstance(op, pipeline_ops_mqtt.MQTTUnsubscribeOperation): - logger.debug("{}({}): unsubscribing from {}".format(self.name, op.name, op.topic)) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def on_complete(cancelled=False): - if cancelled: - op.complete( - error=pipeline_exceptions.OperationCancelled( - "Operation cancelled before UNSUBACK received" - ) - ) - else: - logger.debug( - "{}({}): UNSUBACK received. completing op.".format(self.name, op.name) - ) - op.complete() - - try: - self.transport.unsubscribe(topic=op.topic, callback=on_complete) - except Exception as e: - op.complete(error=e) - - else: - # This code block should not be reached in correct program flow. - # This will raise an error when executed. - self.send_op_down(op) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def _on_mqtt_message_received(self, topic, payload): - """ - Handler that gets called by the protocol library when an incoming message arrives. - Convert that message into a pipeline event and pass it up for someone to handle. 
- """ - logger.debug("{}: message received on topic {}".format(self.name, topic)) - self.send_event_up( - pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic, payload=payload) - ) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def _on_mqtt_connected(self): - """ - Handler that gets called by the transport when it connects. - """ - logger.info("_on_mqtt_connected called") - # Send an event to tell other pipeline stages that we're connected. Do this before - # we do anything else (in case upper stages have any "are we connected" logic. - self.send_event_up(pipeline_events_base.ConnectedEvent()) - - if isinstance(self._pending_connection_op, pipeline_ops_base.ConnectOperation): - logger.debug("{}: completing connect op".format(self.name)) - op = self._pending_connection_op - self._cancel_connection_watchdog(op) - self._pending_connection_op = None - op.complete() - else: - # This should indicate something odd is going on. - # If this occurs, either a connect was completed while there was no pending op, - # OR that a connect was completed while a disconnect op was pending - logger.info( - "{}: Connection was unexpected (no connection op pending)".format(self.name) - ) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def _on_mqtt_connection_failure(self, cause): - """ - Handler that gets called by the transport when a connection fails. - - :param Exception cause: The Exception that caused the connection failure. 
- """ - - logger.info("{}: _on_mqtt_connection_failure called: {}".format(self.name, cause)) - - if isinstance(self._pending_connection_op, pipeline_ops_base.ConnectOperation): - logger.debug("{}: failing connect op".format(self.name)) - op = self._pending_connection_op - self._cancel_connection_watchdog(op) - self._pending_connection_op = None - op.complete(error=cause) - else: - logger.debug("{}: Connection failure was unexpected".format(self.name)) - handle_exceptions.swallow_unraised_exception( - cause, - log_msg="Unexpected connection failure (no pending operation). Safe to ignore.", - log_lvl="info", - ) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def _on_mqtt_disconnected(self, cause=None): - """ - Handler that gets called by the transport when the transport disconnects. - - :param Exception cause: The Exception that caused the disconnection, if any (optional) - """ - if cause: - logger.info("{}: _on_mqtt_disconnect called: {}".format(self.name, cause)) - else: - logger.info("{}: _on_mqtt_disconnect called".format(self.name)) - - # Send an event to tell other pipeline stages that we're disconnected. Do this before - # we do anything else (in case upper stages have any "are we connected" logic.) - # NOTE: Other stages rely on the fact that this occurs before any op that may be in - # progress is completed. Be careful with changing the order things occur here. - self.send_event_up(pipeline_events_base.DisconnectedEvent()) - - if self._pending_connection_op: - - op = self._pending_connection_op - - if isinstance(op, pipeline_ops_base.DisconnectOperation): - logger.debug( - "{}: Expected disconnect - completing pending disconnect op".format(self.name) - ) - # Swallow any errors if we intended to disconnect - even if something went wrong, we - # got to the state we wanted to be in! 
- if cause: - handle_exceptions.swallow_unraised_exception( - cause, - log_msg="Unexpected error while disconnecting - swallowing error", - ) - # Disconnect complete, no longer pending - self._pending_connection_op = None - op.complete() - - else: - logger.debug( - "{}: Unexpected disconnect - completing pending {} operation".format( - self.name, op.name - ) - ) - # Cancel any potential connection watchdog, and clear the pending op - self._cancel_connection_watchdog(op) - self._pending_connection_op = None - # Complete - if cause: - op.complete(error=cause) - else: - op.complete( - error=transport_exceptions.ConnectionDroppedError("transport disconnected") - ) - else: - logger.info("{}: Unexpected disconnect (no pending connection op)".format(self.name)) - - # If there is no connection retry, cancel any transport operations waiting on response - # so that they do not get stuck there. - if not self.nucleus.pipeline_configuration.connection_retry: - logger.debug( - "{}: Connection Retry disabled - cancelling in-flight operations".format( - self.name - ) - ) - # TODO: Remove private access to the op manager (this layer shouldn't know about it) - # This is a stopgap. I didn't want to invest too much infrastructure into a cancel flow - # given that future development of individual operation cancels might affect the - # approach to cancelling inflight ops waiting in the transport. - self.transport._op_manager.cancel_all_operations() - - # Regardless of cause, it is now a ConnectionDroppedError. Log it and swallow it. - # Higher layers will see that we're disconnected and may reconnect as necessary. 
- e = transport_exceptions.ConnectionDroppedError("Unexpected disconnection") - e.__cause__ = cause - self.report_background_exception(e) diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_thread.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_thread.py deleted file mode 100644 index e431b3681..000000000 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_thread.py +++ /dev/null @@ -1,208 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import functools -import logging -import threading -import traceback -from concurrent.futures import ThreadPoolExecutor -from azure.iot.device.common import handle_exceptions - -logger = logging.getLogger(__name__) - -""" -This module contains decorators that are used to marshal code into pipeline and -callback threads and to assert that code is being called in the correct thread. - -The intention of these decorators is to ensure the following: - -1. All pipeline functions execute in a single thread, known as the "pipeline - thread". The `invoke_on_pipeline_thread` and `invoke_on_pipeline_thread_nowait` - decorators cause the decorated function to run on the pipeline thread. - -2. If the pipeline thread is busy running a different function, the invoke - decorators will wait until that function is complete before invoking another - function on that thread. - -3. There is a different thread which is used for callbacks into user code, known - as the the "callback thread". This is not meant for callbacks into pipeline - code. Those callbacks should still execute on the pipeline thread. 
The - `invoke_on_callback_thread_nowait` decorator is used to ensure that callbacks - execute on the callback thread. - -4. Decorators which cause thread switches are used only when necessary. The - pipeline thread is only entered in places where we know that external code is - calling into the pipeline (such as a client API call or a callback from a - third-party library). Likewise, the callback thread is only entered in places - where we know that the pipeline is calling back into client code. - -5. Exceptions raised from the pipeline thread are still able to be caught by - the function which entered the pipeline thread. - -5. Calls into the pipeline thread can either block or not block. Blocking is used - for cases where the caller needs a return value from the pipeline or is - expecting to handle any errors raised from the pipeline thread. Blocking is - not used when the code calling into the pipeline is not waiting for a response - and is not expecting to handle any exceptions, such as protocol library - handlers which call into the pipeline to deliver protocol messages. - -6. Calls into the callback thread could theoretically block, but we currently - only have decorators which enter the callback thread without blocking. This - is done to ensure that client code does not execute on the pipeline thread and - also to ensure that the pipeline thread is not blocked while waiting for client - code to execute. - -These decorators use concurrent.futures.Future and the ThreadPoolExecutor because: - -1. The thread pooling with a pool size of 1 gives us a single thread to run all - pipeline operations and a different (single) thread to run all callbacks. If - the code attempts to run a second pipeline operation (or callback) while a - different one is running, the ThreadPoolExecutor will queue the code until the - first call is completed. - -2. 
The concurrent.futures.Future object properly handles both Exception and - BaseException errors, re-raising them when the Future.result method is called. - threading.Thread.get() was not an option because it doesn't re-raise - BaseException errors when Thread.get is called. -""" - -_executors = {} - - -def _get_named_executor(thread_name): - """ - Get a ThreadPoolExecutor object with the given name. If no such executor exists, - this function will create on with a single worker and assign it to the provided - name. - """ - global _executors - if thread_name not in _executors: - logger.debug("Creating {} executor".format(thread_name)) - _executors[thread_name] = ThreadPoolExecutor(max_workers=1) - return _executors[thread_name] - - -def _invoke_on_executor_thread(func, thread_name, block=True): - """ - Return wrapper to run the function on a given thread. If block==False, - the call returns immediately without waiting for the decorated function to complete. - If block==True, the call waits for the decorated function to complete before returning. - """ - - # Mocks and other callable objects don't have a __name__ attribute. - # Use str() if you can't use __name__ - try: - function_name = func.__name__ - except AttributeError: - function_name = str(func) - - @functools.wraps(func) - def wrapper(*args, **kwargs): - if threading.current_thread().name is not thread_name: - logger.debug("Starting {} in {} thread".format(function_name, thread_name)) - - def thread_proc(): - threading.current_thread().name = thread_name - try: - return func(*args, **kwargs) - except Exception as e: - if not block: - handle_exceptions.handle_background_exception(e) - else: - raise - except BaseException: - if not block: - # This is truly a logger.critical condition. 
Most exceptions in background threads should - # be handled inside the thread and should result in call to handle_background_exception - # if this code is hit, that means something happened which wasn't handled, therefore - # handle_background_exception wasn't called, therefore we need to log this at the highest - # level. - logger.critical("Unhandled exception in background thread") - logger.critical( - "This may cause the background thread to abort and may result in system instability." - ) - traceback.print_exc() - raise - - # TODO: add a timeout here and throw exception on failure - future = _get_named_executor(thread_name).submit(thread_proc) - if block: - return future.result() - else: - return future - else: - logger.debug("Already in {} thread for {}".format(thread_name, function_name)) - return func(*args, **kwargs) - - return wrapper - - -def invoke_on_pipeline_thread(func): - """ - Run the decorated function on the pipeline thread. - """ - return _invoke_on_executor_thread(func=func, thread_name="pipeline") - - -def invoke_on_pipeline_thread_nowait(func): - """ - Run the decorated function on the pipeline thread, but don't wait for it to complete - """ - return _invoke_on_executor_thread(func=func, thread_name="pipeline", block=False) - - -def invoke_on_callback_thread_nowait(func): - """ - Run the decorated function on the callback thread, but don't wait for it to complete - """ - return _invoke_on_executor_thread(func=func, thread_name="callback", block=False) - - -def invoke_on_http_thread_nowait(func): - """ - Run the decorated function on the callback thread, but don't wait for it to complete - """ - # TODO: Refactor this since this is not in the pipeline thread anymore, so we need to pull this into common. - # Also, the max workers eventually needs to be a bigger number, so that needs to be fixed to allow for more than one HTTP Request a a time. 
- return _invoke_on_executor_thread(func=func, thread_name="azure_iot_http", block=False) - - -def _assert_executor_thread(func, thread_name): - """ - Decorator which asserts that the given function only gets called inside the given - thread. - """ - - @functools.wraps(func) - def wrapper(*args, **kwargs): - - assert ( - threading.current_thread().name == thread_name - ), """ - Function {function_name} is not running inside {thread_name} thread. - It should be. You should use invoke_on_{thread_name}_thread(_nowait) to enter the - {thread_name} thread before calling this function. If you're hitting this from - inside a test function, you may need to add the fake_pipeline_thread fixture to - your test. (generally applied on the global pytestmark in a module) """.format( - function_name=func.__name__, thread_name=thread_name - ) - - return func(*args, **kwargs) - - return wrapper - - -def runs_on_pipeline_thread(func): - """ - Decorator which marks a function as only running inside the pipeline thread. - """ - return _assert_executor_thread(func=func, thread_name="pipeline") - - -def runs_on_http_thread(func): - """ - Decorator which marks a function as only running inside the http thread. - """ - return _assert_executor_thread(func=func, thread_name="azure_iot_http") diff --git a/azure-iot-device/azure/iot/device/common/transport_exceptions.py b/azure-iot-device/azure/iot/device/common/transport_exceptions.py deleted file mode 100644 index 1c572ec59..000000000 --- a/azure-iot-device/azure/iot/device/common/transport_exceptions.py +++ /dev/null @@ -1,62 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -"""This module defines errors that may be raised from a transport""" - - -class ConnectionFailedError(Exception): - """ - Connection failed to be established - """ - - pass - - -class ConnectionDroppedError(Exception): - """ - Previously established connection was dropped - """ - - pass - - -class NoConnectionError(Exception): - """ - There is no connection - """ - - -class UnauthorizedError(Exception): - """ - Authorization was rejected - """ - - pass - - -class ProtocolClientError(Exception): - """ - Error returned from protocol client library - """ - - pass - - -class TlsExchangeAuthError(Exception): - """ - Error returned when transport layer exchanges - result in a SSLCertVerification error. - """ - - pass - - -class ProtocolProxyError(Exception): - """ - All proxy-related errors. - TODO : Not sure what to name it here. There is a class called Proxy Error already in Pysocks - """ - - pass diff --git a/azure-iot-device/azure/iot/device/config.py b/azure-iot-device/azure/iot/device/config.py new file mode 100644 index 000000000..2a62c7578 --- /dev/null +++ b/azure-iot-device/azure/iot/device/config.py @@ -0,0 +1,182 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import logging +import socks +import ssl +from typing import Optional, Any +from .sastoken import SasTokenProvider + +# TODO: add typings for imports +# TODO: update docs to ensure types are correct +# TODO: can these just be TypeDicts? + + +logger = logging.getLogger(__name__) + +# The max keep alive is determined by the load balancer currently. 
+MAX_KEEP_ALIVE_SECS = 1740 + + +string_to_socks_constant_map = {"HTTP": socks.HTTP, "SOCKS4": socks.SOCKS4, "SOCKS5": socks.SOCKS5} +socks_constant_to_string_map = {socks.HTTP: "HTTP", socks.SOCKS4: "SOCKS4", socks.SOCKS5: "SOCKS5"} + + +class ProxyOptions: + """ + A class containing various options to send traffic through proxy servers by enabling + proxying of MQTT connection. + """ + + def __init__( + self, + proxy_type: str, + proxy_address: str, + proxy_port: Optional[int] = None, + proxy_username: Optional[str] = None, + proxy_password: Optional[str] = None, + ): + """ + Initializer for proxy options. + :param str proxy_type: The type of the proxy server. This can be one of three possible choices: "HTTP", "SOCKS4", or "SOCKS5" + :param str proxy_addr: IP address or DNS name of proxy server + :param int proxy_port: The port of the proxy server. Defaults to 1080 for socks and 8080 for http. + :param str proxy_username: (optional) username for SOCKS5 proxy, or userid for SOCKS4 proxy.This parameter is ignored if an HTTP server is being used. + If it is not provided, authentication will not be used (servers may accept unauthenticated requests). + :param str proxy_password: (optional) This parameter is valid only for SOCKS5 servers and specifies the respective password for the username provided. + """ + # TODO: port default + # TODO: is that documentation about auth only being used on SOCKS accurate? Seems inaccurate. + (self.proxy_type, self.proxy_type_socks) = _format_proxy_type(proxy_type) + self.proxy_address = proxy_address + if proxy_port is None: + self.proxy_port = _derive_default_proxy_port(self.proxy_type) + else: + self.proxy_port = int(proxy_port) + self.proxy_username = proxy_username + self.proxy_password = proxy_password + + +class ClientConfig: + """ + Class for storing all configurations/options shared across the + Azure IoT Python Device Client Library. 
+ """ + + def __init__( + self, + *, + ssl_context: ssl.SSLContext, + hostname: str, + sastoken_provider: Optional[SasTokenProvider] = None, + proxy_options: Optional[ProxyOptions] = None, + keep_alive: int = 60, + auto_reconnect: bool = True, + websockets: bool = False, + ) -> None: + """Initializer for ClientConfig + + :param str hostname: The hostname being connected to + :param sastoken_provider: Object that can provide SasTokens + :type sastoken_provider: :class:`SasTokenProvider` + :param proxy_options: Details of proxy configuration + :type proxy_options: :class:`azure.iot.device.common.models.ProxyOptions` + :param ssl_context: SSLContext to use with the client + :type ssl_context: :class:`ssl.SSLContext` + :param int keepalive: Maximum period in seconds between communications with the + broker. + :param bool auto_reconnect: Indicates if dropped connection should result in attempts to + re-establish it + :param bool websockets: Enabling/disabling websockets in MQTT. This feature is relevant + if a firewall blocks port 8883 from use. + """ + # Network + self.hostname = hostname + self.proxy_options = proxy_options + + # Auth + self.sastoken_provider = sastoken_provider + self.ssl_context = ssl_context + + # MQTT + self.keep_alive = _sanitize_keep_alive(keep_alive) + self.auto_reconnect = auto_reconnect + self.websockets = websockets + + +class IoTHubClientConfig(ClientConfig): + def __init__( + self, + *, + device_id: str, + module_id: Optional[str] = None, + product_info: str = "", + **kwargs: Any, + ) -> None: + """ + Config object used for IoTHub clients, containing all relevant details and options. + + :param str device_id: The device identity being used with the IoTHub + :param str module_id: The module identity being used with the IoTHub + :param str product_info: A custom identification string. 
+ + Additional parameters found in the docstring of the parent class + """ + self.device_id = device_id + self.module_id = module_id + self.product_info = product_info + super().__init__(**kwargs) + + +class ProvisioningClientConfig(ClientConfig): + def __init__(self, *, registration_id: str, id_scope: str, **kwargs) -> None: + """ + Config object used for Provisioning clients, containing all relevant details and options. + + :param str registration_id: The device registration identity being provisioned + :param str id_scope: The identity of the provisioning service being used + """ + self.registration_id = registration_id + self.id_scope = id_scope + super().__init__(**kwargs) + + +# Sanitization # + + +def _format_proxy_type(proxy_type): + """Returns a tuple of formats for proxy type (string, socks library constant)""" + try: + return (proxy_type, string_to_socks_constant_map[proxy_type]) + except KeyError: + # Backwards compatibility for when we used the socks library constants in the API + try: + return (socks_constant_to_string_map[proxy_type], proxy_type) + except KeyError: + raise ValueError("Invalid Proxy Type") + + +def _derive_default_proxy_port(proxy_type): + if proxy_type == "HTTP": + return 8080 + else: + return 1080 + + +def _sanitize_keep_alive(keep_alive): + try: + keep_alive = int(keep_alive) + except (ValueError, TypeError): + raise TypeError("Invalid type for 'keep alive'. Must be a numeric value.") + + if keep_alive <= 0: + # Not allowing a keep alive of 0 as this would mean frequent ping exchanges. 
+ raise ValueError("'keep alive' must be greater than 0") + + if keep_alive > MAX_KEEP_ALIVE_SECS: + raise ValueError("'keep_alive' cannot exceed 1740 seconds (29 minutes)") + + return keep_alive diff --git a/azure-iot-device/azure/iot/device/common/auth/connection_string.py b/azure-iot-device/azure/iot/device/connection_string.py similarity index 95% rename from azure-iot-device/azure/iot/device/common/auth/connection_string.py rename to azure-iot-device/azure/iot/device/connection_string.py index 6960d8392..cc43ee2c6 100644 --- a/azure-iot-device/azure/iot/device/common/auth/connection_string.py +++ b/azure-iot-device/azure/iot/device/connection_string.py @@ -30,6 +30,44 @@ X509, ] +# TODO: does this module need revision for V3? + + +class ConnectionString(object): + """Key/value mappings for connection details. + Uses the same syntax as dictionary + """ + + def __init__(self, connection_string): + """Initializer for ConnectionString + + :param str connection_string: String with connection details provided by Azure + :raises: ValueError if provided connection_string is invalid + """ + self._dict = _parse_connection_string(connection_string) + self._strrep = connection_string + + def __contains__(self, item): + return item in self._dict + + def __getitem__(self, key): + return self._dict[key] + + def __repr__(self): + return self._strrep + + def get(self, key, default=None): + """Return the value for key if key is in the dictionary, else default + + :param str key: The key to retrieve a value for + :param str default: The default value returned if a key is not found + :returns: The value for the given key + """ + try: + return self._dict[key] + except KeyError: + return default + def _parse_connection_string(connection_string): """Return a dictionary of values contained in a given connection string""" @@ -60,7 +98,7 @@ def _validate_keys(d): device_id = d.get(DEVICE_ID) x509 = d.get(X509) - if shared_access_key and x509: + if shared_access_key and x509 and 
x509.lower() == "true": raise ValueError("Invalid Connection String - Mixed authentication scheme") # This logic could be expanded to return the category of ConnectionString @@ -70,36 +108,3 @@ def _validate_keys(d): pass else: raise ValueError("Invalid Connection String - Incomplete") - - -class ConnectionString(object): - """Key/value mappings for connection details. - Uses the same syntax as dictionary - """ - - def __init__(self, connection_string): - """Initializer for ConnectionString - - :param str connection_string: String with connection details provided by Azure - :raises: ValueError if provided connection_string is invalid - """ - self._dict = _parse_connection_string(connection_string) - self._strrep = connection_string - - def __getitem__(self, key): - return self._dict[key] - - def __repr__(self): - return self._strrep - - def get(self, key, default=None): - """Return the value for key if key is in the dictionary, else default - - :param str key: The key to retrieve a value for - :param str default: The default value returned if a key is not found - :returns: The value for the given key - """ - try: - return self._dict[key] - except KeyError: - return default diff --git a/azure-iot-device/azure/iot/device/constant.py b/azure-iot-device/azure/iot/device/constant.py index dc69935bc..d572b6267 100644 --- a/azure-iot-device/azure/iot/device/constant.py +++ b/azure-iot-device/azure/iot/device/constant.py @@ -6,17 +6,18 @@ """This module defines constants for use across the azure-iot-device package """ -VERSION = "2.12.0" +VERSION = "3.0.0b2" IOTHUB_IDENTIFIER = "azure-iot-device-iothub-py" PROVISIONING_IDENTIFIER = "azure-iot-device-provisioning-py" -IOTHUB_API_VERSION = "2019-10-01" +IOTHUB_API_VERSION = "2020-09-30" PROVISIONING_API_VERSION = "2019-03-31" SECURITY_MESSAGE_INTERFACE_ID = "urn:azureiot:Security:SecurityAgent:1" +PROVISIONING_GLOBAL_ENDPOINT = "global.azure-devices-provisioning.net" + +# TODO: find somewhere else for this 
TELEMETRY_MESSAGE_SIZE_LIMIT = 262144 -# The max keep alive is determined by the load balancer currently. -MAX_KEEP_ALIVE_SECS = 1740 + # Everything in digital twin is defined here # as things are extremely dynamic and subject to sudden changes DIGITAL_TWIN_PREFIX = "dtmi" -DIGITAL_TWIN_API_VERSION = "2020-09-30" DIGITAL_TWIN_QUERY_HEADER = "model-id" diff --git a/azure-iot-device/azure/iot/device/custom_typing.py b/azure-iot-device/azure/iot/device/custom_typing.py new file mode 100644 index 000000000..8a41dd3a3 --- /dev/null +++ b/azure-iot-device/azure/iot/device/custom_typing.py @@ -0,0 +1,72 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from typing import Union, Dict, List, Tuple, Callable, Awaitable, TypeVar +from typing_extensions import TypedDict, ParamSpec + + +_P = ParamSpec("_P") +_R = TypeVar("_R") +FunctionOrCoroutine = Union[Callable[_P, _R], Callable[_P, Awaitable[_R]]] + + +# typing does not support recursion, so we must use forward references here (PEP484) +JSONSerializable = Union[ + Dict[str, "JSONSerializable"], + List["JSONSerializable"], + Tuple["JSONSerializable", ...], + str, + int, + float, + bool, + None, +] +# TODO: verify that the JSON specification requires str as keys in dict. Not sure why that's defined here. 
+ + +Twin = Dict[str, Dict[str, JSONSerializable]] +TwinPatch = Dict[str, JSONSerializable] + + +class DirectMethodParameters(TypedDict): + methodName: str + payload: JSONSerializable + connectTimeoutInSeconds: int + responseTimeoutInSeconds: int + + +class DirectMethodResult(TypedDict): + status: int + payload: JSONSerializable + + +class StorageInfo(TypedDict): + correlationId: str + hostName: str + containerName: str + blobName: str + sasToken: str + + +class RegistrationState(TypedDict): + deviceId: str + assignedHub: str + subStatus: str + createdDateTimeUtc: str + lastUpdatedDateTimeUtc: str + etag: str + payload: JSONSerializable + + +class DeviceRegistrationRequest(TypedDict): + registrationId: str + payload: JSONSerializable + # TODO csr + + +class RegistrationResult(TypedDict): + operationId: str + status: str + registrationState: RegistrationState diff --git a/azure-iot-device/azure/iot/device/iothub/edge_hsm.py b/azure-iot-device/azure/iot/device/edge_hsm.py similarity index 86% rename from azure-iot-device/azure/iot/device/iothub/edge_hsm.py rename to azure-iot-device/azure/iot/device/edge_hsm.py index a443994f7..7654767f1 100644 --- a/azure-iot-device/azure/iot/device/iothub/edge_hsm.py +++ b/azure-iot-device/azure/iot/device/edge_hsm.py @@ -1,26 +1,23 @@ -# -------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
# -------------------------------------------------------------------------- - +import base64 import logging import json -import base64 -import requests -import requests_unixsocket -import urllib -from azure.iot.device.common.auth.signing_mechanism import SigningMechanism -from azure.iot.device import user_agent +import requests # type: ignore +import requests_unixsocket # type: ignore +import urllib.parse +from typing import Union +from . import user_agent +from .exceptions import IoTEdgeError +from .signing_mechanism import SigningMechanism requests_unixsocket.monkeypatch() logger = logging.getLogger(__name__) -class IoTEdgeError(Exception): - pass - - class IoTEdgeHsm(SigningMechanism): """ Constructor for instantiating a iot hsm object. This is an object that @@ -34,7 +31,9 @@ class IoTEdgeHsm(SigningMechanism): SharedAccessSignature string which can be used to authenticate with Iot Edge """ - def __init__(self, module_id, generation_id, workload_uri, api_version): + def __init__( + self, module_id: str, generation_id: str, workload_uri: str, api_version: str + ) -> None: """ Constructor for instantiating a Azure IoT Edge HSM object @@ -48,7 +47,8 @@ def __init__(self, module_id, generation_id, workload_uri, api_version): self.generation_id = generation_id self.workload_uri = _format_socket_uri(workload_uri) - def get_certificate(self): + # TODO: Use async http to make use of this being a coroutine + async def get_certificate(self) -> str: """ Return the server verification certificate from the trust bundle that can be used to validate the server-side SSL TLS connection that we use to talk to Edge @@ -80,7 +80,8 @@ def get_certificate(self): raise IoTEdgeError("No certificate in trust bundle") from e return cert - def sign(self, data_str): + # TODO: Use async http to make use of this being a coroutine + async def sign(self, data_str: Union[str, bytes]) -> str: """ Use the IoTEdge HSM to sign a piece of string data. 
The caller should then insert the returned value (the signature) into the 'sig' field of a SharedAccessSignature string. @@ -92,7 +93,13 @@ def sign(self, data_str): :raises: IoTEdgeError if unable to sign the data. """ - encoded_data_str = base64.b64encode(data_str.encode("utf-8")).decode() + # Convert data_str to bytes (if not already) + if isinstance(data_str, str): + data_bytes = data_str.encode("utf-8") + else: + data_bytes = data_str + + encoded_data_str = base64.b64encode(data_bytes).decode() path = "{workload_uri}modules/{module_id}/genid/{gen_id}/sign".format( workload_uri=self.workload_uri, module_id=self.module_id, gen_id=self.generation_id @@ -121,7 +128,7 @@ def sign(self, data_str): return signed_data_str # what format is this? string? bytes? -def _format_socket_uri(old_uri): +def _format_socket_uri(old_uri: str) -> str: """ This function takes a socket URI in one form and converts it into another form. diff --git a/azure-iot-device/azure/iot/device/exceptions.py b/azure-iot-device/azure/iot/device/exceptions.py index 0b1f8622c..3f243589f 100644 --- a/azure-iot-device/azure/iot/device/exceptions.py +++ b/azure-iot-device/azure/iot/device/exceptions.py @@ -3,183 +3,49 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -"""This module defines an exception surface, exposed as part of the azure.iot.device library API""" +"""Define Azure IoT domain user-facing exceptions to be shared across package""" +from .mqtt_client import ( # noqa: F401 (Importing directly to re-export) + MQTTError, + MQTTConnectionFailedError, + MQTTConnectionDroppedError, +) -# Currently, we are redefining many lower level exceptions in this file, in order to present an API -# surface that will be consistent and unchanging (even though lower level exceptions may change). -# Potentially, this could be somewhat relaxed in the future as the design solidifies. 
-# ~~~ EXCEPTIONS ~~~ - - -class OperationCancelled(Exception): - """An operation was cancelled""" - - pass - - -class OperationTimeout(Exception): - """An operation timed out""" - - pass - - -# ~~~ CLIENT ERRORS ~~~ - - -class ClientError(Exception): - """Generic error for a client""" +# Client/Session Exceptions +# TODO: Should this be here? Only if HTTP stack still exists. If not, move to specific file +# TODO: Should this just be a generic ClientError that could be used across clients? +class IoTHubClientError(Exception): + """Represents a failure from the IoTHub Client""" pass -class ConnectionFailedError(ClientError): - """Failed to establish a connection""" +class SessionError(Exception): + """Represents a failure from the Session object""" pass -class ConnectionDroppedError(ClientError): - """Lost connection while executing operation""" +# Service Exceptions +class IoTHubError(Exception): + """Represents a failure reported by IoT Hub""" pass -class NoConnectionError(ClientError): - """Operation could not be completed because no connection has been established""" +class IoTEdgeError(Exception): + """Represents a failure reported by IoT Edge""" pass -class CredentialError(ClientError): - """Could not connect client using given credentials""" +class ProvisioningServiceError(Exception): + """Represents a failure reported by Provisioning Service""" pass -# ~~~ SERVICE ERRORS ~~~ - - -class ServiceError(Exception): - """Error received from an Azure IoT service""" +class IoTEdgeEnvironmentError(Exception): + """Represents a failure retrieving data from the IoT Edge environment""" pass - - -# NOTE: These are not (yet) in use. -# Because of this they have been commented out to prevent confusion. 
- -# class ArgumentError(ServiceError): -# """Service returned 400""" - -# pass - - -# class UnauthorizedError(ServiceError): -# """Service returned 401""" - -# pass - - -# class QuotaExceededError(ServiceError): -# """Service returned 403""" - -# pass - - -# class NotFoundError(ServiceError): -# """Service returned 404""" - -# pass - - -# class DeviceTimeoutError(ServiceError): -# """Service returned 408""" - -# # TODO: is this a method call error? If so, do we retry? -# pass - - -# class DeviceAlreadyExistsError(ServiceError): -# """Service returned 409""" - -# pass - - -# class InvalidEtagError(ServiceError): -# """Service returned 412""" - -# pass - - -# class MessageTooLargeError(ServiceError): -# """Service returned 413""" - -# pass - - -# class ThrottlingError(ServiceError): -# """Service returned 429""" - -# pass - - -# class InternalServiceError(ServiceError): -# """Service returned 500""" - -# pass - - -# class BadDeviceResponseError(ServiceError): -# """Service returned 502""" - -# # TODO: is this a method invoke thing? 
-# pass - - -# class ServiceUnavailableError(ServiceError): -# """Service returned 503""" - -# pass - - -# class ServiceTimeoutError(ServiceError): -# """Service returned 504""" - -# pass - - -# class FailedStatusCodeError(ServiceError): -# """Service returned unknown status code""" - -# pass - - -# status_code_to_error = { -# 400: ArgumentError, -# 401: UnauthorizedError, -# 403: QuotaExceededError, -# 404: NotFoundError, -# 408: DeviceTimeoutError, -# 409: DeviceAlreadyExistsError, -# 412: InvalidEtagError, -# 413: MessageTooLargeError, -# 429: ThrottlingError, -# 500: InternalServiceError, -# 502: BadDeviceResponseError, -# 503: ServiceUnavailableError, -# 504: ServiceTimeoutError, -# } - - -# def error_from_status_code(status_code, message=None): -# """ -# Return an Error object from a failed status code - -# :param int status_code: Status code returned from failed operation -# :returns: Error object -# """ -# if status_code in status_code_to_error: -# return status_code_to_error[status_code](message) -# else: -# return FailedStatusCodeError(message) diff --git a/azure-iot-device/azure/iot/device/http_path_iothub.py b/azure-iot-device/azure/iot/device/http_path_iothub.py new file mode 100644 index 000000000..31225655b --- /dev/null +++ b/azure-iot-device/azure/iot/device/http_path_iothub.py @@ -0,0 +1,45 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import logging +import urllib +from typing import Optional + +logger = logging.getLogger(__name__) + + +def get_direct_method_invoke_path(device_id: str, module_id: Optional[str] = None) -> str: + """ + :return: The relative path for invoking methods from one module to a device or module. 
It is of the format + /twins/uri_encode($device_id)/modules/uri_encode($module_id)/methods + """ + if module_id: + return "/twins/{device_id}/modules/{module_id}/methods".format( + device_id=urllib.parse.quote_plus(device_id), + module_id=urllib.parse.quote_plus(module_id), + ) + else: + return "/twins/{device_id}/methods".format(device_id=urllib.parse.quote_plus(device_id)) + + +def get_storage_info_for_blob_path(device_id: str): + """ + This does not take a module_id since get_storage_info_for_blob_path should only ever be invoked on device clients. + + :return: The relative path for getting the storage sdk credential information from IoT Hub. It is of the format + /devices/uri_encode($device_id)/files + """ + return "/devices/{}/files".format(urllib.parse.quote_plus(device_id)) + + +def get_notify_blob_upload_status_path(device_id: str): + """ + This does not take a module_id since get_notify_blob_upload_status_path should only ever be invoked on device clients. + + :return: The relative path for getting the storage sdk credential information from IoT Hub. It is of the format + /devices/uri_encode($device_id)/files/notifications + """ + return "/devices/{}/files/notifications".format(urllib.parse.quote_plus(device_id)) diff --git a/azure-iot-device/azure/iot/device/iothub/__init__.py b/azure-iot-device/azure/iot/device/iothub/__init__.py deleted file mode 100644 index f81aea8b3..000000000 --- a/azure-iot-device/azure/iot/device/iothub/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Azure IoT Hub Device Library - -This library provides functionality for communicating with the Azure IoT Hub -as a Device or Module. 
-""" - -from .sync_clients import IoTHubDeviceClient, IoTHubModuleClient -from .models import Message, MethodRequest, MethodResponse - -__all__ = ["IoTHubDeviceClient", "IoTHubModuleClient", "Message", "MethodRequest", "MethodResponse"] diff --git a/azure-iot-device/azure/iot/device/iothub/abstract_clients.py b/azure-iot-device/azure/iot/device/iothub/abstract_clients.py deleted file mode 100644 index ed712e298..000000000 --- a/azure-iot-device/azure/iot/device/iothub/abstract_clients.py +++ /dev/null @@ -1,911 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module contains abstract classes for the various clients of the Azure IoT Hub Device SDK -""" - -import abc -import logging -import threading -import os -import io -import time -from . import pipeline -from .pipeline import constant as pipeline_constant -from azure.iot.device.common.auth import connection_string as cs -from azure.iot.device.common.auth import sastoken as st -from azure.iot.device.iothub import client_event -from azure.iot.device import exceptions -from azure.iot.device.common import auth, handle_exceptions -from . import edge_hsm - -logger = logging.getLogger(__name__) - - -def _validate_kwargs(exclude=[], **kwargs): - """Helper function to validate user provided kwargs. 
- Raises TypeError if an invalid option has been provided""" - valid_kwargs = [ - "server_verification_cert", - "gateway_hostname", - "websockets", - "cipher", - "product_info", - "proxy_options", - "sastoken_ttl", - "keep_alive", - "auto_connect", - "connection_retry", - "connection_retry_interval", - "ensure_desired_properties", - ] - - for kwarg in kwargs: - if (kwarg not in valid_kwargs) or (kwarg in exclude): - raise TypeError("Unsupported keyword argument: '{}'".format(kwarg)) - - -def _get_config_kwargs(**kwargs): - """Get the subset of kwargs which pertain the config object""" - valid_config_kwargs = [ - "server_verification_cert", - "gateway_hostname", - "websockets", - "cipher", - "product_info", - "proxy_options", - "keep_alive", - "auto_connect", - "connection_retry", - "connection_retry_interval", - "ensure_desired_properties", - ] - - config_kwargs = {} - for kwarg in kwargs: - if kwarg in valid_config_kwargs: - config_kwargs[kwarg] = kwargs[kwarg] - return config_kwargs - - -def _form_sas_uri(hostname, device_id, module_id=None): - if module_id: - return "{hostname}/devices/{device_id}/modules/{module_id}".format( - hostname=hostname, device_id=device_id, module_id=module_id - ) - else: - return "{hostname}/devices/{device_id}".format(hostname=hostname, device_id=device_id) - - -def _extract_sas_uri_values(uri): - d = {} - items = uri.split("/") - if len(items) != 3 and len(items) != 5: - raise ValueError("Invalid SAS URI") - if items[1] != "devices": - raise ValueError("Cannot extract device id from SAS URI") - if len(items) > 3 and items[3] != "modules": - raise ValueError("Cannot extract module id from SAS URI") - d["hostname"] = items[0] - d["device_id"] = items[2] - try: - d["module_id"] = items[4] - except IndexError: - d["module_id"] = None - return d - - -# Receive Type constant defs -RECEIVE_TYPE_NONE_SET = "none_set" # Type of receiving has not been set -RECEIVE_TYPE_HANDLER = "handler" # Only use handlers for receive -RECEIVE_TYPE_API = 
"api" # Only use APIs for receive - - -class AbstractIoTHubClient(abc.ABC): - """A superclass representing a generic IoTHub client. - This class needs to be extended for specific clients. - """ - - def __init__(self, mqtt_pipeline, http_pipeline): - """Initializer for a generic client. - - :param mqtt_pipeline: The pipeline used to connect to the IoTHub endpoint. - :type mqtt_pipeline: :class:`azure.iot.device.iothub.pipeline.MQTTPipeline` - """ - self._mqtt_pipeline = mqtt_pipeline - self._http_pipeline = http_pipeline - - self._inbox_manager = None # this will be overridden in child class - self._handler_manager = None # this will be overridden in child class - self._receive_type = RECEIVE_TYPE_NONE_SET - self._client_lock = threading.Lock() - - def _on_connected(self): - """Helper handler that is called upon an iothub pipeline connect""" - logger.info("Connection State - Connected") - client_event_inbox = self._inbox_manager.get_client_event_inbox() - # Only add a ClientEvent to the inbox if the Handler Manager is capable of dealing with it - if self._handler_manager.handling_client_events: - event = client_event.ClientEvent(client_event.CONNECTION_STATE_CHANGE) - client_event_inbox.put(event) - # Ensure that all handlers are running now that connection is re-established. - self._handler_manager.ensure_running() - - def _on_disconnected(self): - """Helper handler that is called upon an iothub pipeline disconnect""" - logger.info("Connection State - Disconnected") - client_event_inbox = self._inbox_manager.get_client_event_inbox() - # Only add a ClientEvent to the inbox if the Handler Manager is capable of dealing with it - if self._handler_manager.handling_client_events: - event = client_event.ClientEvent(client_event.CONNECTION_STATE_CHANGE) - client_event_inbox.put(event) - # Locally stored method requests on client are cleared. - # They will be resent by IoTHub on reconnect. 
- self._inbox_manager.clear_all_method_requests() - logger.info("Cleared all pending method requests due to disconnect") - - def _on_new_sastoken_required(self): - """Helper handler that is called upon the iothub pipeline needing new SAS token""" - logger.info("New SasToken required from user") - client_event_inbox = self._inbox_manager.get_client_event_inbox() - # Only add a ClientEvent to the inbox if the Handler Manager is capable of dealing with it - if self._handler_manager.handling_client_events: - event = client_event.ClientEvent(client_event.NEW_SASTOKEN_REQUIRED) - client_event_inbox.put(event) - - def _on_background_exception(self, e): - """Helper handler that is called upon an iothub pipeline background exception""" - handle_exceptions.handle_background_exception(e) - client_event_inbox = self._inbox_manager.get_client_event_inbox() - # Only add a ClientEvent to the inbox if the Handler Manager is capable of dealing with it - if self._handler_manager.handling_client_events: - event = client_event.ClientEvent(client_event.BACKGROUND_EXCEPTION, e) - client_event_inbox.put(event) - - def _check_receive_mode_is_api(self): - """Call this function first in EVERY receive API""" - with self._client_lock: - if self._receive_type is RECEIVE_TYPE_NONE_SET: - # Lock the client to ONLY use receive APIs (no handlers) - self._receive_type = RECEIVE_TYPE_API - elif self._receive_type is RECEIVE_TYPE_HANDLER: - raise exceptions.ClientError( - "Cannot use receive APIs - receive handler(s) have already been set" - ) - else: - pass - - def _check_receive_mode_is_handler(self): - """Call this function first in EVERY handler setter""" - with self._client_lock: - if self._receive_type is RECEIVE_TYPE_NONE_SET: - # Lock the client to ONLY use receive handlers (no APIs) - self._receive_type = RECEIVE_TYPE_HANDLER - # Set the inbox manager to use unified msg receives - self._inbox_manager.use_unified_msg_mode = True - elif self._receive_type is RECEIVE_TYPE_API: - raise 
exceptions.ClientError( - "Cannot set receive handlers - receive APIs have already been used" - ) - else: - pass - - def _replace_user_supplied_sastoken(self, sastoken_str): - """ - Replaces the pipeline's NonRenewableSasToken with a new one based on a provided - sastoken string. Also does validation. - This helper only updates the PipelineConfig - it does not reauthorize the connection. - """ - if not isinstance( - self._mqtt_pipeline.pipeline_configuration.sastoken, st.NonRenewableSasToken - ): - raise exceptions.ClientError( - "Cannot update sastoken when client was not created with one" - ) - # Create new SasToken - try: - new_token_o = st.NonRenewableSasToken(sastoken_str) - except st.SasTokenError as e: - new_err = ValueError("Invalid SasToken provided") - new_err.__cause__ = e - raise new_err - # Extract values from SasToken - vals = _extract_sas_uri_values(new_token_o.resource_uri) - # Validate new token - if type(self).__name__ == "IoTHubDeviceClient" and vals["module_id"]: - raise ValueError("Provided SasToken is for a module") - if type(self).__name__ == "IoTHubModuleClient" and not vals["module_id"]: - raise ValueError("Provided SasToken is for a device") - if self._mqtt_pipeline.pipeline_configuration.device_id != vals["device_id"]: - raise ValueError("Provided SasToken does not match existing device id") - if self._mqtt_pipeline.pipeline_configuration.module_id != vals["module_id"]: - raise ValueError("Provided SasToken does not match existing module id") - if self._mqtt_pipeline.pipeline_configuration.hostname != vals["hostname"]: - raise ValueError("Provided SasToken does not match existing hostname") - if new_token_o.expiry_time < int(time.time()): - raise ValueError("Provided SasToken has already expired") - # Set token - # NOTE: We only need to set this on MQTT because this is a reference to the same object - # that is stored in HTTP. The HTTP pipeline is updated implicitly. 
- self._mqtt_pipeline.pipeline_configuration.sastoken = new_token_o - - @abc.abstractmethod - def _generic_receive_handler_setter(self, handler_name, feature_name, new_handler): - # Will be implemented differently in child classes, but define here for static analysis - pass - - @classmethod - def create_from_connection_string(cls, connection_string, **kwargs): - """ - Instantiate the client from a IoTHub device or module connection string. - - :param str connection_string: The connection string for the IoTHub you wish to connect to. - - :param str server_verification_cert: Configuration Option. The trusted certificate chain. - Necessary when using connecting to an endpoint which has a non-standard root of trust, - such as a protocol gateway. - :param bool websockets: Configuration Option. Default is False. Set to true if using MQTT - over websockets. - :param cipher: Configuration Option. Cipher suite(s) for TLS/SSL, as a string in - "OpenSSL cipher list format" or as a list of cipher suite strings. - :type cipher: str or list(str) - :param str product_info: Configuration Option. Default is empty string. The string contains - arbitrary product info which is appended to the user agent string. - :param proxy_options: Options for sending traffic through proxy servers. - :type proxy_options: :class:`azure.iot.device.ProxyOptions` - :param int sastoken_ttl: The time to live (in seconds) for the created SasToken used for - authentication. Default is 3600 seconds (1 hour). - :param int keep_alive: Maximum period in seconds between communications with the - broker. If no other messages are being exchanged, this controls the - rate at which the client will send ping messages to the broker. - If not provided default value of 60 secs will be used. - :param bool auto_connect: Automatically connect the client to IoTHub when a method is - invoked which requires a connection to be established. 
(Default: True) - :param bool connection_retry: Attempt to re-establish a dropped connection (Default: True) - :param int connection_retry_interval: Interval, in seconds, between attempts to - re-establish a dropped connection (Default: 10) - :param bool ensure_desired_properties: Ensure the most recent desired properties patch has - been received upon re-connections (Default:True) - - :raises: ValueError if given an invalid connection_string. - :raises: TypeError if given an unsupported parameter. - - :returns: An instance of an IoTHub client that uses a connection string for authentication. - """ - # TODO: Make this device/module specific and reject non-matching connection strings. - - # Ensure no invalid kwargs were passed by the user - excluded_kwargs = ["gateway_hostname"] - _validate_kwargs(exclude=excluded_kwargs, **kwargs) - - # Create SasToken - connection_string = cs.ConnectionString(connection_string) - if connection_string.get(cs.X509) is not None: - raise ValueError( - "Use the .create_from_x509_certificate() method instead when using X509 certificates" - ) - uri = _form_sas_uri( - hostname=connection_string[cs.HOST_NAME], - device_id=connection_string[cs.DEVICE_ID], - module_id=connection_string.get(cs.MODULE_ID), - ) - signing_mechanism = auth.SymmetricKeySigningMechanism( - key=connection_string[cs.SHARED_ACCESS_KEY] - ) - token_ttl = kwargs.get("sastoken_ttl", 3600) - try: - sastoken = st.RenewableSasToken(uri, signing_mechanism, ttl=token_ttl) - except st.SasTokenError as e: - new_err = ValueError("Could not create a SasToken using provided values") - new_err.__cause__ = e - raise new_err - # Pipeline Config setup - config_kwargs = _get_config_kwargs(**kwargs) - pipeline_configuration = pipeline.IoTHubPipelineConfig( - device_id=connection_string[cs.DEVICE_ID], - module_id=connection_string.get(cs.MODULE_ID), - hostname=connection_string[cs.HOST_NAME], - gateway_hostname=connection_string.get(cs.GATEWAY_HOST_NAME), - sastoken=sastoken, - 
**config_kwargs - ) - if cls.__name__ == "IoTHubDeviceClient": - pipeline_configuration.blob_upload = True - - # Pipeline setup - http_pipeline = pipeline.HTTPPipeline(pipeline_configuration) - mqtt_pipeline = pipeline.MQTTPipeline(pipeline_configuration) - - return cls(mqtt_pipeline, http_pipeline) - - @classmethod - def create_from_sastoken(cls, sastoken, **kwargs): - """Instantiate the client from a pre-created SAS Token string - - :param str sastoken: The SAS Token string - - :param str server_verification_cert: Configuration Option. The trusted certificate chain. - Necessary when using connecting to an endpoint which has a non-standard root of trust, - such as a protocol gateway. - :param str gateway_hostname: Configuration Option. The gateway hostname for the gateway - device. - :param bool websockets: Configuration Option. Default is False. Set to true if using MQTT - over websockets. - :param cipher: Configuration Option. Cipher suite(s) for TLS/SSL, as a string in - "OpenSSL cipher list format" or as a list of cipher suite strings. - :type cipher: str or list(str) - :param str product_info: Configuration Option. Default is empty string. The string contains - arbitrary product info which is appended to the user agent string. - :param proxy_options: Options for sending traffic through proxy servers. - :type proxy_options: :class:`azure.iot.device.ProxyOptions` - :param int keep_alive: Maximum period in seconds between communications with the - broker. If no other messages are being exchanged, this controls the - rate at which the client will send ping messages to the broker. - If not provided default value of 60 secs will be used. - :param bool auto_connect: Automatically connect the client to IoTHub when a method is - invoked which requires a connection to be established. 
(Default: True) - :param bool connection_retry: Attempt to re-establish a dropped connection (Default: True) - :param int connection_retry_interval: Interval, in seconds, between attempts to - re-establish a dropped connection (Default: 10) - :param bool ensure_desired_properties: Ensure the most recent desired properties patch has - been received upon re-connections (Default:True) - - :raises: TypeError if given an unsupported parameter. - :raises: ValueError if the sastoken parameter is invalid. - """ - # Ensure no invalid kwargs were passed by the user - excluded_kwargs = ["sastoken_ttl"] - _validate_kwargs(exclude=excluded_kwargs, **kwargs) - - # Create SasToken object from string - try: - sastoken_o = st.NonRenewableSasToken(sastoken) - except st.SasTokenError as e: - new_err = ValueError("Invalid SasToken provided") - new_err.__cause__ = e - raise new_err - # Extract values from SasToken - vals = _extract_sas_uri_values(sastoken_o.resource_uri) - if cls.__name__ == "IoTHubDeviceClient" and vals["module_id"]: - raise ValueError("Provided SasToken is for a module") - if cls.__name__ == "IoTHubModuleClient" and not vals["module_id"]: - raise ValueError("Provided SasToken is for a device") - if sastoken_o.expiry_time < int(time.time()): - raise ValueError("Provided SasToken has already expired") - # Pipeline Config setup - config_kwargs = _get_config_kwargs(**kwargs) - pipeline_configuration = pipeline.IoTHubPipelineConfig( - device_id=vals["device_id"], - module_id=vals["module_id"], - hostname=vals["hostname"], - sastoken=sastoken_o, - **config_kwargs - ) - if cls.__name__ == "IoTHubDeviceClient": - pipeline_configuration.blob_upload = True # Blob Upload is a feature on Device Clients - - # Pipeline setup - http_pipeline = pipeline.HTTPPipeline(pipeline_configuration) - mqtt_pipeline = pipeline.MQTTPipeline(pipeline_configuration) - - return cls(mqtt_pipeline, http_pipeline) - - @abc.abstractmethod - def shutdown(self): - pass - - @abc.abstractmethod - def 
connect(self): - pass - - @abc.abstractmethod - def disconnect(self): - pass - - @abc.abstractmethod - def update_sastoken(self, sastoken): - pass - - @abc.abstractmethod - def send_message(self, message): - pass - - @abc.abstractmethod - def receive_method_request(self, method_name=None): - pass - - @abc.abstractmethod - def send_method_response(self, method_request, payload, status): - pass - - @abc.abstractmethod - def get_twin(self): - pass - - @abc.abstractmethod - def patch_twin_reported_properties(self, reported_properties_patch): - pass - - @abc.abstractmethod - def receive_twin_desired_properties_patch(self): - pass - - @property - def connected(self): - """ - Read-only property to indicate if the transport is connected or not. - """ - return self._mqtt_pipeline.connected - - @property - def on_connection_state_change(self): - """The handler function or coroutine that will be called when the connection state changes. - - The function or coroutine definition should take no positional arguments. - """ - return self._handler_manager.on_connection_state_change - - @on_connection_state_change.setter - def on_connection_state_change(self, value): - self._handler_manager.on_connection_state_change = value - - @property - def on_new_sastoken_required(self): - """The handler function or coroutine that will be called when the client requires a new - SAS token. This will happen approximately 2 minutes before the SAS Token expires. - On Windows platforms, if the lifespan exceeds approximately 49 days, a new token will - be required after those 49 days regardless of how long the SAS lifespan is. - - Note that this handler is ONLY necessary when using a client created via the - .create_from_sastoken() method. - - The new token can be provided in your function or coroutine via use of the client's - .update_sastoken() method. - - The function or coroutine definition should take no positional arguments. 
- """ - return self._handler_manager.on_new_sastoken_required - - @on_new_sastoken_required.setter - def on_new_sastoken_required(self, value): - self._handler_manager.on_new_sastoken_required = value - - @property - def on_background_exception(self): - """The handler function or coroutine will be called when a background exception occurs. - - The function or coroutine definition should take one positional argument (the exception - object)""" - return self._handler_manager.on_background_exception - - @on_background_exception.setter - def on_background_exception(self, value): - self._handler_manager.on_background_exception = value - - @abc.abstractproperty - def on_message_received(self): - # Defined below on AbstractIoTHubDeviceClient / AbstractIoTHubModuleClient - pass - - @property - def on_method_request_received(self): - """The handler function or coroutine that will be called when a method request is received. - - Remember to acknowledge the method request in your function or coroutine via use of the - client's .send_method_response() method. - - The function or coroutine definition should take one positional argument (the - :class:`azure.iot.device.MethodRequest` object)""" - return self._handler_manager.on_method_request_received - - @on_method_request_received.setter - def on_method_request_received(self, value): - self._generic_receive_handler_setter( - "on_method_request_received", pipeline_constant.METHODS, value - ) - - @property - def on_twin_desired_properties_patch_received(self): - """The handler function or coroutine that will be called when a twin desired properties - patch is received. 
- - The function or coroutine definition should take one positional argument (the twin patch - in the form of a JSON dictionary object)""" - return self._handler_manager.on_twin_desired_properties_patch_received - - @on_twin_desired_properties_patch_received.setter - def on_twin_desired_properties_patch_received(self, value): - self._generic_receive_handler_setter( - "on_twin_desired_properties_patch_received", pipeline_constant.TWIN_PATCHES, value - ) - - -class AbstractIoTHubDeviceClient(AbstractIoTHubClient): - @classmethod - def create_from_x509_certificate(cls, x509, hostname, device_id, **kwargs): - """ - Instantiate a client using X509 certificate authentication. - - :param str hostname: Host running the IotHub. - Can be found in the Azure portal in the Overview tab as the string hostname. - :param x509: The complete x509 certificate object. - To use the certificate the enrollment object needs to contain cert - (either the root certificate or one of the intermediate CA certificates). - If the cert comes from a CER file, it needs to be base64 encoded. - :type x509: :class:`azure.iot.device.X509` - :param str device_id: The ID used to uniquely identify a device in the IoTHub - - :param str server_verification_cert: Configuration Option. The trusted certificate chain. - Necessary when using connecting to an endpoint which has a non-standard root of trust, - such as a protocol gateway. - :param str gateway_hostname: Configuration Option. The gateway hostname for the gateway - device. - :param bool websockets: Configuration Option. Default is False. Set to true if using MQTT - over websockets. - :param cipher: Configuration Option. Cipher suite(s) for TLS/SSL, as a string in - "OpenSSL cipher list format" or as a list of cipher suite strings. - :type cipher: str or list(str) - :param str product_info: Configuration Option. Default is empty string. The string contains - arbitrary product info which is appended to the user agent string. 
- :param proxy_options: Options for sending traffic through proxy servers. - :type proxy_options: :class:`azure.iot.device.ProxyOptions` - :param int keep_alive: Maximum period in seconds between communications with the - broker. If no other messages are being exchanged, this controls the - rate at which the client will send ping messages to the broker. - If not provided default value of 60 secs will be used. - :param bool auto_connect: Automatically connect the client to IoTHub when a method is - invoked which requires a connection to be established. (Default: True) - :param bool connection_retry: Attempt to re-establish a dropped connection (Default: True) - :param int connection_retry_interval: Interval, in seconds, between attempts to - re-establish a dropped connection (Default: 10) - :param bool ensure_desired_properties: Ensure the most recent desired properties patch has - been received upon re-connections (Default:True) - - :raises: TypeError if given an unsupported parameter. - - :returns: An instance of an IoTHub client that uses an X509 certificate for authentication. - """ - # Ensure no invalid kwargs were passed by the user - excluded_kwargs = ["sastoken_ttl"] - _validate_kwargs(exclude=excluded_kwargs, **kwargs) - - # Pipeline Config setup - config_kwargs = _get_config_kwargs(**kwargs) - pipeline_configuration = pipeline.IoTHubPipelineConfig( - device_id=device_id, hostname=hostname, x509=x509, **config_kwargs - ) - pipeline_configuration.blob_upload = True # Blob Upload is a feature on Device Clients - pipeline_configuration.ensure_desired_properties = True - - # Pipeline setup - http_pipeline = pipeline.HTTPPipeline(pipeline_configuration) - mqtt_pipeline = pipeline.MQTTPipeline(pipeline_configuration) - - return cls(mqtt_pipeline, http_pipeline) - - @classmethod - def create_from_symmetric_key(cls, symmetric_key, hostname, device_id, **kwargs): - """ - Instantiate a client using symmetric key authentication. 
- - :param symmetric_key: The symmetric key. - :param str hostname: Host running the IotHub. - Can be found in the Azure portal in the Overview tab as the string hostname. - :param device_id: The device ID - - :param str server_verification_cert: Configuration Option. The trusted certificate chain. - Necessary when using connecting to an endpoint which has a non-standard root of trust, - such as a protocol gateway. - :param str gateway_hostname: Configuration Option. The gateway hostname for the gateway - device. - :param bool websockets: Configuration Option. Default is False. Set to true if using MQTT - over websockets. - :param cipher: Configuration Option. Cipher suite(s) for TLS/SSL, as a string in - "OpenSSL cipher list format" or as a list of cipher suite strings. - :type cipher: str or list(str) - :param str product_info: Configuration Option. Default is empty string. The string contains - arbitrary product info which is appended to the user agent string. - :param proxy_options: Options for sending traffic through proxy servers. - :type proxy_options: :class:`azure.iot.device.ProxyOptions` - :param int sastoken_ttl: The time to live (in seconds) for the created SasToken used for - authentication. Default is 3600 seconds (1 hour) - :param int keep_alive: Maximum period in seconds between communications with the - broker. If no other messages are being exchanged, this controls the - rate at which the client will send ping messages to the broker. - If not provided default value of 60 secs will be used. - :param bool auto_connect: Automatically connect the client to IoTHub when a method is - invoked which requires a connection to be established. 
(Default: True) - :param bool connection_retry: Attempt to re-establish a dropped connection (Default: True) - :param int connection_retry_interval: Interval, in seconds, between attempts to - re-establish a dropped connection (Default: 10) - :param bool ensure_desired_properties: Ensure the most recent desired properties patch has - been received upon re-connections (Default:True) - - :raises: TypeError if given an unsupported parameter. - :raises: ValueError if the provided parameters are invalid. - - :return: An instance of an IoTHub client that uses a symmetric key for authentication. - """ - # Ensure no invalid kwargs were passed by the user - _validate_kwargs(**kwargs) - - # Create SasToken - uri = _form_sas_uri(hostname=hostname, device_id=device_id) - signing_mechanism = auth.SymmetricKeySigningMechanism(key=symmetric_key) - token_ttl = kwargs.get("sastoken_ttl", 3600) - try: - sastoken = st.RenewableSasToken(uri, signing_mechanism, ttl=token_ttl) - except st.SasTokenError as e: - new_err = ValueError("Could not create a SasToken using provided values") - new_err.__cause__ = e - raise new_err - - # Pipeline Config setup - config_kwargs = _get_config_kwargs(**kwargs) - pipeline_configuration = pipeline.IoTHubPipelineConfig( - device_id=device_id, hostname=hostname, sastoken=sastoken, **config_kwargs - ) - pipeline_configuration.blob_upload = True # Blob Upload is a feature on Device Clients - pipeline_configuration.ensure_desired_properties = True - - # Pipeline setup - http_pipeline = pipeline.HTTPPipeline(pipeline_configuration) - mqtt_pipeline = pipeline.MQTTPipeline(pipeline_configuration) - - return cls(mqtt_pipeline, http_pipeline) - - @abc.abstractmethod - def receive_message(self): - pass - - @abc.abstractmethod - def get_storage_info_for_blob(self, blob_name): - pass - - @abc.abstractmethod - def notify_blob_upload_status( - self, correlation_id, is_success, status_code, status_description - ): - pass - - @property - def on_message_received(self): - 
"""The handler function or coroutine that will be called when a message is received. - - The function or coroutine definition should take one positional argument (the - :class:`azure.iot.device.Message` object)""" - return self._handler_manager.on_message_received - - @on_message_received.setter - def on_message_received(self, value): - self._generic_receive_handler_setter( - "on_message_received", pipeline_constant.C2D_MSG, value - ) - - -class AbstractIoTHubModuleClient(AbstractIoTHubClient): - @classmethod - def create_from_edge_environment(cls, **kwargs): - """ - Instantiate the client from the IoT Edge environment. - - This method can only be run from inside an IoT Edge container, or in a debugging - environment configured for Edge development (e.g. Visual Studio, Visual Studio Code) - - :param bool websockets: Configuration Option. Default is False. Set to true if using MQTT - over websockets. - :param cipher: Configuration Option. Cipher suite(s) for TLS/SSL, as a string in - "OpenSSL cipher list format" or as a list of cipher suite strings. - :type cipher: str or list(str) - :param str product_info: Configuration Option. Default is empty string. The string contains - arbitrary product info which is appended to the user agent string. - :param proxy_options: Options for sending traffic through proxy servers. - :type proxy_options: :class:`azure.iot.device.ProxyOptions` - :param int sastoken_ttl: The time to live (in seconds) for the created SasToken used for - authentication. Default is 3600 seconds (1 hour) - :param int keep_alive: Maximum period in seconds between communications with the - broker. If no other messages are being exchanged, this controls the - rate at which the client will send ping messages to the broker. - If not provided default value of 60 secs will be used. - :param bool auto_connect: Automatically connect the client to IoTHub when a method is - invoked which requires a connection to be established. 
(Default: True) - :param bool connection_retry: Attempt to re-establish a dropped connection (Default: True) - :param int connection_retry_interval: Interval, in seconds, between attempts to - re-establish a dropped connection (Default: 10) - - :raises: OSError if the IoT Edge container is not configured correctly. - :raises: ValueError if debug variables are invalid. - :raises: TypeError if given an unsupported parameter. - - :returns: An instance of an IoTHub client that uses the IoT Edge environment for - authentication. - """ - # Ensure no invalid kwargs were passed by the user - excluded_kwargs = ["server_verification_cert", "gateway_hostname"] - _validate_kwargs(exclude=excluded_kwargs, **kwargs) - - # First try the regular Edge container variables - try: - hostname = os.environ["IOTEDGE_IOTHUBHOSTNAME"] - device_id = os.environ["IOTEDGE_DEVICEID"] - module_id = os.environ["IOTEDGE_MODULEID"] - gateway_hostname = os.environ["IOTEDGE_GATEWAYHOSTNAME"] - module_generation_id = os.environ["IOTEDGE_MODULEGENERATIONID"] - workload_uri = os.environ["IOTEDGE_WORKLOADURI"] - api_version = os.environ["IOTEDGE_APIVERSION"] - except KeyError: - # As a fallback, try the Edge local dev variables for debugging. - # These variables are set by VS/VS Code in order to allow debugging - # of Edge application code in a non-Edge dev environment. 
- try: - connection_string = os.environ["EdgeHubConnectionString"] - ca_cert_filepath = os.environ["EdgeModuleCACertificateFile"] - except KeyError as e: - new_err = OSError("IoT Edge environment not configured correctly") - new_err.__cause__ = e - raise new_err - - # Read the certificate file to pass it on as a string - # TODO: variant server_verification_cert file vs data object that would remove the need for this file open - try: - with io.open(ca_cert_filepath, mode="r") as ca_cert_file: - server_verification_cert = ca_cert_file.read() - except FileNotFoundError: - raise - except OSError as e: - raise ValueError("Invalid CA certificate file") from e - - # Extract config values from connection string - connection_string = cs.ConnectionString(connection_string) - try: - device_id = connection_string[cs.DEVICE_ID] - module_id = connection_string[cs.MODULE_ID] - hostname = connection_string[cs.HOST_NAME] - gateway_hostname = connection_string[cs.GATEWAY_HOST_NAME] - except KeyError: - raise ValueError("Invalid Connection String") - - # Use Symmetric Key authentication for local dev experience. 
- signing_mechanism = auth.SymmetricKeySigningMechanism( - key=connection_string[cs.SHARED_ACCESS_KEY] - ) - - else: - # Use an HSM for authentication in the general case - hsm = edge_hsm.IoTEdgeHsm( - module_id=module_id, - generation_id=module_generation_id, - workload_uri=workload_uri, - api_version=api_version, - ) - try: - server_verification_cert = hsm.get_certificate() - except edge_hsm.IoTEdgeError as e: - new_err = OSError("Unexpected failure in IoTEdge") - new_err.__cause__ = e - raise new_err - signing_mechanism = hsm - - # Create SasToken - uri = _form_sas_uri(hostname=hostname, device_id=device_id, module_id=module_id) - token_ttl = kwargs.get("sastoken_ttl", 3600) - try: - sastoken = st.RenewableSasToken(uri, signing_mechanism, ttl=token_ttl) - except st.SasTokenError as e: - new_err = ValueError( - "Could not create a SasToken using the values provided, or in the Edge environment" - ) - new_err.__cause__ = e - raise new_err - - # Pipeline Config setup - config_kwargs = _get_config_kwargs(**kwargs) - pipeline_configuration = pipeline.IoTHubPipelineConfig( - device_id=device_id, - module_id=module_id, - hostname=hostname, - gateway_hostname=gateway_hostname, - sastoken=sastoken, - server_verification_cert=server_verification_cert, - **config_kwargs - ) - pipeline_configuration.ensure_desired_properties = True - - pipeline_configuration.method_invoke = ( - True # Method Invoke is allowed on modules created from edge environment - ) - - # Pipeline setup - http_pipeline = pipeline.HTTPPipeline(pipeline_configuration) - mqtt_pipeline = pipeline.MQTTPipeline(pipeline_configuration) - - return cls(mqtt_pipeline, http_pipeline) - - @classmethod - def create_from_x509_certificate(cls, x509, hostname, device_id, module_id, **kwargs): - """ - Instantiate a client using X509 certificate authentication. - - :param str hostname: Host running the IotHub. - Can be found in the Azure portal in the Overview tab as the string hostname. 
- :param x509: The complete x509 certificate object. - To use the certificate the enrollment object needs to contain cert - (either the root certificate or one of the intermediate CA certificates). - If the cert comes from a CER file, it needs to be base64 encoded. - :type x509: :class:`azure.iot.device.X509` - :param str device_id: The ID used to uniquely identify a device in the IoTHub - :param str module_id: The ID used to uniquely identify a module on a device on the IoTHub. - - :param str server_verification_cert: Configuration Option. The trusted certificate chain. - Necessary when using connecting to an endpoint which has a non-standard root of trust, - such as a protocol gateway. - :param str gateway_hostname: Configuration Option. The gateway hostname for the gateway - device. - :param bool websockets: Configuration Option. Default is False. Set to true if using MQTT - over websockets. - :param cipher: Configuration Option. Cipher suite(s) for TLS/SSL, as a string in - "OpenSSL cipher list format" or as a list of cipher suite strings. - :type cipher: str or list(str) - :param str product_info: Configuration Option. Default is empty string. The string contains - arbitrary product info which is appended to the user agent string. - :param proxy_options: Options for sending traffic through proxy servers. - :type proxy_options: :class:`azure.iot.device.ProxyOptions` - :param int keep_alive: Maximum period in seconds between communications with the - broker. If no other messages are being exchanged, this controls the - rate at which the client will send ping messages to the broker. - If not provided default value of 60 secs will be used. - :param bool auto_connect: Automatically connect the client to IoTHub when a method is - invoked which requires a connection to be established. 
(Default: True) - :param bool connection_retry: Attempt to re-establish a dropped connection (Default: True) - :param int connection_retry_interval: Interval, in seconds, between attempts to - re-establish a dropped connection (Default: 10) - :param bool ensure_desired_properties: Ensure the most recent desired properties patch has - been received upon re-connections (Default:True) - - :raises: TypeError if given an unsupported parameter. - - :returns: An instance of an IoTHub client that uses an X509 certificate for authentication. - """ - # Ensure no invalid kwargs were passed by the user - excluded_kwargs = ["sastoken_ttl"] - _validate_kwargs(exclude=excluded_kwargs, **kwargs) - - # Pipeline Config setup - config_kwargs = _get_config_kwargs(**kwargs) - pipeline_configuration = pipeline.IoTHubPipelineConfig( - device_id=device_id, module_id=module_id, hostname=hostname, x509=x509, **config_kwargs - ) - pipeline_configuration.ensure_desired_properties = True - - # Pipeline setup - http_pipeline = pipeline.HTTPPipeline(pipeline_configuration) - mqtt_pipeline = pipeline.MQTTPipeline(pipeline_configuration) - return cls(mqtt_pipeline, http_pipeline) - - @abc.abstractmethod - def send_message_to_output(self, message, output_name): - pass - - @abc.abstractmethod - def receive_message_on_input(self, input_name): - pass - - @abc.abstractmethod - def invoke_method(self, method_params, device_id, module_id=None): - pass - - @property - def on_message_received(self): - """The handler function or coroutine that will be called when an input message is received. 
- - The function definition or coroutine should take one positional argument (the - :class:`azure.iot.device.Message` object)""" - return self._handler_manager.on_message_received - - @on_message_received.setter - def on_message_received(self, value): - self._generic_receive_handler_setter( - "on_message_received", pipeline_constant.INPUT_MSG, value - ) diff --git a/azure-iot-device/azure/iot/device/iothub/aio/__init__.py b/azure-iot-device/azure/iot/device/iothub/aio/__init__.py deleted file mode 100644 index 1af0fb8af..000000000 --- a/azure-iot-device/azure/iot/device/iothub/aio/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Azure IoT Hub Device SDK - Asynchronous - -This SDK provides asynchronous functionality for communicating with the Azure IoT Hub -as a Device or Module. -""" - -from .async_clients import IoTHubDeviceClient, IoTHubModuleClient - -__all__ = ["IoTHubDeviceClient", "IoTHubModuleClient"] diff --git a/azure-iot-device/azure/iot/device/iothub/aio/async_clients.py b/azure-iot-device/azure/iot/device/iothub/aio/async_clients.py deleted file mode 100644 index c57f3ec14..000000000 --- a/azure-iot-device/azure/iot/device/iothub/aio/async_clients.py +++ /dev/null @@ -1,709 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module contains user-facing asynchronous clients for the -Azure IoTHub Device SDK for Python. 
-""" - -import logging -import asyncio -import deprecation -from azure.iot.device.common import async_adapter -from azure.iot.device.iothub.abstract_clients import ( - AbstractIoTHubClient, - AbstractIoTHubDeviceClient, - AbstractIoTHubModuleClient, -) -from azure.iot.device.iothub.models import Message -from azure.iot.device.iothub.pipeline import constant -from azure.iot.device.iothub.pipeline import exceptions as pipeline_exceptions -from azure.iot.device import exceptions -from azure.iot.device.iothub.inbox_manager import InboxManager -from .async_inbox import AsyncClientInbox -from . import async_handler_manager, loop_management -from azure.iot.device import constant as device_constant - -logger = logging.getLogger(__name__) - - -async def handle_result(callback): - try: - return await callback.completion() - except pipeline_exceptions.ConnectionDroppedError as e: - raise exceptions.ConnectionDroppedError("Lost connection to IoTHub") from e - except pipeline_exceptions.ConnectionFailedError as e: - raise exceptions.ConnectionFailedError("Could not connect to IoTHub") from e - except pipeline_exceptions.NoConnectionError as e: - raise exceptions.NoConnectionError("Client is not connected to IoTHub") from e - except pipeline_exceptions.UnauthorizedError as e: - raise exceptions.CredentialError("Credentials invalid, could not connect") from e - except pipeline_exceptions.ProtocolClientError as e: - raise exceptions.ClientError("Error in the IoTHub client") from e - except pipeline_exceptions.TlsExchangeAuthError as e: - raise exceptions.ClientError("Error in the IoTHub client due to TLS exchanges.") from e - except pipeline_exceptions.ProtocolProxyError as e: - raise exceptions.ClientError( - "Error in the IoTHub client raised due to proxy connections." 
- ) from e - except pipeline_exceptions.PipelineNotRunning as e: - raise exceptions.ClientError("Client has already been shut down") from e - except pipeline_exceptions.OperationCancelled as e: - raise exceptions.OperationCancelled("Operation was cancelled before completion") from e - except pipeline_exceptions.OperationTimeout as e: - raise exceptions.OperationTimeout("Could not complete operation before timeout") from e - except Exception as e: - raise exceptions.ClientError("Unexpected failure") from e - - -class GenericIoTHubClient(AbstractIoTHubClient): - """A super class representing a generic asynchronous client. - This class needs to be extended for specific clients. - """ - - def __init__(self, **kwargs): - """Initializer for a generic asynchronous client. - - This initializer should not be called directly. - Instead, use one of the 'create_from_' class methods to instantiate - - :param mqtt_pipeline: The MQTTPipeline used for the client - :type mqtt_pipeline: :class:`azure.iot.device.iothub.pipeline.MQTTPipeline` - :param http_pipeline: The HTTPPipeline used for the client - :type http_pipeline: :class:`azure.iot.device.iothub.pipeline.HTTPPipeline` - """ - # Depending on the subclass calling this __init__, there could be different arguments, - # and the super() call could call a different class, due to the different MROs - # in the class hierarchies of different clients. Thus, args here must be passed along as - # **kwargs. 
- super().__init__(**kwargs) - self._inbox_manager = InboxManager(inbox_type=AsyncClientInbox) - self._handler_manager = async_handler_manager.AsyncHandlerManager(self._inbox_manager) - - # Set pipeline handlers for client events - self._mqtt_pipeline.on_connected = self._on_connected - self._mqtt_pipeline.on_disconnected = self._on_disconnected - self._mqtt_pipeline.on_new_sastoken_required = self._on_new_sastoken_required - self._mqtt_pipeline.on_background_exception = self._on_background_exception - - # Set pipeline handlers for data receives - self._mqtt_pipeline.on_method_request_received = self._inbox_manager.route_method_request - self._mqtt_pipeline.on_twin_patch_received = self._inbox_manager.route_twin_patch - - async def _enable_feature(self, feature_name): - """Enable an Azure IoT Hub feature - - :param feature_name: The name of the feature to enable. - See azure.iot.device.common.pipeline.constant for possible values. - """ - logger.info("Enabling feature:" + feature_name + "...") - if not self._mqtt_pipeline.feature_enabled[feature_name]: - # Enable the feature if not already enabled - enable_feature_async = async_adapter.emulate_async(self._mqtt_pipeline.enable_feature) - - callback = async_adapter.AwaitableCallback() - await enable_feature_async(feature_name, callback=callback) - await handle_result(callback) - - logger.info("Successfully enabled feature:" + feature_name) - else: - # This branch shouldn't be reached, but in case it is, log it - logger.info("Feature ({}) already enabled - skipping".format(feature_name)) - - async def _disable_feature(self, feature_name): - """Disable an Azure IoT Hub feature - - :param feature_name: The name of the feature to enable. - See azure.iot.device.common.pipeline.constant for possible values. 
- """ - logger.info("Disabling feature: {}...".format(feature_name)) - if self._mqtt_pipeline.feature_enabled[feature_name]: - # Disable the feature if not already disabled - disable_feature_async = async_adapter.emulate_async(self._mqtt_pipeline.disable_feature) - - callback = async_adapter.AwaitableCallback() - await disable_feature_async(feature_name, callback=callback) - await handle_result(callback) - - logger.info("Successfully disabled feature: {}".format(feature_name)) - else: - # This branch shouldn't be reached, but in case it is, log it - logger.info("Feature ({}) already disabled - skipping".format(feature_name)) - - def _generic_receive_handler_setter(self, handler_name, feature_name, new_handler): - """Set a receive handler on the handler manager and enable the corresponding feature. - - This is a synchronous call (yes, even though this is the async client), meaning that this - function will not return until the feature has been enabled (if necessary). - - :param str handler_name: The name of the handler on the handler manager to set - :param str feature_name: The name of the pipeline feature that corresponds to the handler - :param new_handler: The function to be set as the handler - """ - self._check_receive_mode_is_handler() - # Set the handler on the handler manager - setattr(self._handler_manager, handler_name, new_handler) - - # Enable the feature if necessary - if new_handler is not None and not self._mqtt_pipeline.feature_enabled[feature_name]: - # We have to call this on a loop running on a different thread in order to ensure - # the setter can be called both within a coroutine (with a running event loop) and - # outside of a coroutine (where no event loop is currently running) - loop = loop_management.get_client_internal_loop() - fut = asyncio.run_coroutine_threadsafe(self._enable_feature(feature_name), loop=loop) - fut.result() - - # Disable the feature if necessary - elif new_handler is None and 
self._mqtt_pipeline.feature_enabled[feature_name]: - # We have to call this on a loop running on a different thread in order to ensure - # the setter can be called both within a coroutine (with a running event loop) and - # outside of a coroutine (where no event loop is currently running) - loop = loop_management.get_client_internal_loop() - fut = asyncio.run_coroutine_threadsafe(self._disable_feature(feature_name), loop=loop) - fut.result() - - async def shutdown(self): - """Shut down the client for graceful exit. - - Once this method is called, any attempts at further client calls will result in a - ClientError being raised - - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. - """ - logger.info("Initiating client shutdown") - # Note that client disconnect does the following: - # - Disconnects the pipeline - # - Resolves all pending receiver handler calls - # - Stops receiver handler threads - await self.disconnect() - - # Note that shutting down does the following: - # - Disconnects the MQTT pipeline - # - Stops MQTT pipeline threads - logger.debug("Beginning pipeline shutdown operation") - shutdown_async = async_adapter.emulate_async(self._mqtt_pipeline.shutdown) - callback = async_adapter.AwaitableCallback() - await shutdown_async(callback=callback) - await handle_result(callback) - logger.debug("Completed pipeline shutdown operation") - - # Stop the Client Event handlers now that everything else is completed - self._handler_manager.stop(receiver_handlers_only=False) - - # Yes, that means the pipeline is disconnected twice (well, actually three times if you - # consider that the client-level disconnect causes two pipeline-level disconnects for - # reasons explained in comments in the client's .disconnect() method). 
- # - # This last disconnect that occurs as a result of the pipeline shutdown is a bit different - # from the first though, in that it's more "final" and can't simply just be reconnected. - - # Note also that only the MQTT pipeline is shut down. The reason is twofold: - # 1. There are no known issues related to graceful exit if the HTTP pipeline is not - # explicitly shut down - # 2. The HTTP pipeline is planned for eventual removal from the client - # In light of these two facts, it seemed irrelevant to spend time implementing shutdown - # capability for HTTP pipeline. - logger.info("Client shutdown complete") - - async def connect(self): - """Connects the client to an Azure IoT Hub or Azure IoT Edge Hub instance. - - The destination is chosen based on the credentials passed via the auth_provider parameter - that was provided when this object was initialized. - - :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid - and a connection cannot be established. - :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a - connection results in failure. - :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost - during execution. - :raises: :class:`azure.iot.device.exceptions.OperationTimeout` if the connection times out. - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. - """ - logger.info("Connecting to Hub...") - connect_async = async_adapter.emulate_async(self._mqtt_pipeline.connect) - - callback = async_adapter.AwaitableCallback() - await connect_async(callback=callback) - await handle_result(callback) - - logger.info("Successfully connected to Hub") - - async def disconnect(self): - """Disconnect the client from the Azure IoT Hub or Azure IoT Edge Hub instance. - - It is recommended that you make sure to call this coroutine when you are completely done - with the your client instance. 
- - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. - """ - logger.info("Disconnecting from Hub...") - - logger.debug("Executing initial disconnect") - disconnect_async = async_adapter.emulate_async(self._mqtt_pipeline.disconnect) - callback = async_adapter.AwaitableCallback() - await disconnect_async(callback=callback) - await handle_result(callback) - logger.debug("Successfully executed initial disconnect") - - # Note that in the process of stopping the handlers and resolving pending calls - # a user-supplied handler may cause a reconnection to occur - logger.debug("Stopping handlers...") - self._handler_manager.stop(receiver_handlers_only=True) - logger.debug("Successfully stopped handlers") - - # Disconnect again to ensure disconnection has occurred due to the issue mentioned above - logger.debug("Executing secondary disconnect...") - disconnect_async = async_adapter.emulate_async(self._mqtt_pipeline.disconnect) - callback = async_adapter.AwaitableCallback() - await disconnect_async(callback=callback) - await handle_result(callback) - logger.debug("Successfully executed secondary disconnect") - - # It's also possible that in the (very short) time between stopping the handlers and - # the second disconnect, additional items were received (e.g. C2D Message) - # Currently, this isn't really possible to accurately check due to a - # race condition / thread timing issue with inboxes where we can't guarantee how many - # items are truly in them. - # It has always been true of this client, even before handlers. - # - # However, even if the race condition is addressed, that will only allow us to log that - # messages were lost. To actually fix the problem, IoTHub needs to support MQTT5 so that - # we can unsubscribe from receiving data. 
- - logger.info("Successfully disconnected from Hub") - - async def update_sastoken(self, sastoken): - """ - Update the client's SAS Token used for authentication, then reauthorizes the connection. - - This API can only be used if the client was initially created with a SAS Token. - - :param str sastoken: The new SAS Token string for the client to use - - :raises: ValueError if the sastoken parameter is invalid - :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid - and a connection cannot be re-established. - :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a re-establishing - the connection results in failure. - :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost - during execution. - :raises: :class:`azure.iot.device.exceptions.OperationTimeout` if the reauthorization - attempt times out. - :raises: :class:`azure.iot.device.exceptions.ClientError` if the client was not initially - created with a SAS token. - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. - """ - self._replace_user_supplied_sastoken(sastoken) - - # Reauthorize the connection - logger.info("Reauthorizing connection with Hub...") - reauth_connection_async = async_adapter.emulate_async( - self._mqtt_pipeline.reauthorize_connection - ) - callback = async_adapter.AwaitableCallback() - await reauth_connection_async(callback=callback) - await handle_result(callback) - # NOTE: Currently due to the MQTT3 implementation, the pipeline reauthorization will return - # after the disconnect. It does not wait for the reconnect to complete. This means that - # any errors that may occur as part of the connect will not return via this callback. - # They will instead go to the background exception handler. 
- - logger.info("Successfully reauthorized connection to Hub") - - async def send_message(self, message): - """Sends a message to the default events endpoint on the Azure IoT Hub or Azure IoT Edge Hub instance. - - If the connection to the service has not previously been opened by a call to connect, this - function will open the connection before sending the event. - - :param message: The actual message to send. Anything passed that is not an instance of the - Message class will be converted to Message object. - :type message: :class:`azure.iot.device.Message` or str - - :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid - and a connection cannot be established. - :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a - connection results in failure. - :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost - during execution. - :raises: :class:`azure.iot.device.exceptions.OperationTimeout` if connection attempt - times out - :raises: :class:`azure.iot.device.exceptions.NoConnectionError` if the client is not - connected (and there is no auto-connect enabled) - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. - :raises: ValueError if the message fails size validation. 
- """ - if not isinstance(message, Message): - message = Message(message) - - if message.get_size() > device_constant.TELEMETRY_MESSAGE_SIZE_LIMIT: - raise ValueError("Size of telemetry message can not exceed 256 KB.") - - logger.info("Sending message to Hub...") - send_message_async = async_adapter.emulate_async(self._mqtt_pipeline.send_message) - - callback = async_adapter.AwaitableCallback() - await send_message_async(message, callback=callback) - await handle_result(callback) - - logger.info("Successfully sent message to Hub") - - @deprecation.deprecated( - deprecated_in="2.3.0", - current_version=device_constant.VERSION, - details="We recommend that you use the .on_method_request_received property to set a handler instead", - ) - async def receive_method_request(self, method_name=None): - """Receive a method request via the Azure IoT Hub or Azure IoT Edge Hub. - - If no method request is yet available, will wait until it is available. - - :param str method_name: Optionally provide the name of the method to receive requests for. - If this parameter is not given, all methods not already being specifically targeted by - a different call to receive_method will be received. - - :returns: MethodRequest object representing the received method request. - :rtype: :class:`azure.iot.device.MethodRequest` - """ - self._check_receive_mode_is_api() - - if not self._mqtt_pipeline.feature_enabled[constant.METHODS]: - await self._enable_feature(constant.METHODS) - - method_inbox = self._inbox_manager.get_method_request_inbox(method_name) - - logger.info("Waiting for method request...") - method_request = await method_inbox.get() - logger.info("Received method request") - return method_request - - async def send_method_response(self, method_response): - """Send a response to a method request via the Azure IoT Hub or Azure IoT Edge Hub. 
- - If the connection to the service has not previously been opened by a call to connect, this - function will open the connection before sending the event. - - :param method_response: The MethodResponse to send - :type method_response: :class:`azure.iot.device.MethodResponse` - - :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid - and a connection cannot be established. - :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a - connection results in failure. - :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost - during execution. - :raises: :class:`azure.iot.device.exceptions.OperationTimeout` if connection attempt - times out - :raises: :class:`azure.iot.device.exceptions.NoConnectionError` if the client is not - connected (and there is no auto-connect enabled) - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. - """ - logger.info("Sending method response to Hub...") - send_method_response_async = async_adapter.emulate_async( - self._mqtt_pipeline.send_method_response - ) - - callback = async_adapter.AwaitableCallback() - - # TODO: maybe consolidate method_request, result and status into a new object - await send_method_response_async(method_response, callback=callback) - await handle_result(callback) - - logger.info("Successfully sent method response to Hub") - - async def get_twin(self): - """ - Gets the device or module twin from the Azure IoT Hub or Azure IoT Edge Hub service. - - :returns: Complete Twin as a JSON dict - :rtype: dict - - :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid - and a connection cannot be established. - :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a - connection results in failure. 
- :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost - during execution. - :raises: :class:`azure.iot.device.exceptions.OperationTimeout` if connection attempt - times out - :raises: :class:`azure.iot.device.exceptions.NoConnectionError` if the client is not - connected (and there is no auto-connect enabled) - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. - """ - logger.info("Getting twin") - - if not self._mqtt_pipeline.feature_enabled[constant.TWIN]: - await self._enable_feature(constant.TWIN) - - get_twin_async = async_adapter.emulate_async(self._mqtt_pipeline.get_twin) - - callback = async_adapter.AwaitableCallback(return_arg_name="twin") - await get_twin_async(callback=callback) - twin = await handle_result(callback) - logger.info("Successfully retrieved twin") - return twin - - async def patch_twin_reported_properties(self, reported_properties_patch): - """ - Update reported properties with the Azure IoT Hub or Azure IoT Edge Hub service. - - If the service returns an error on the patch operation, this function will raise the - appropriate error. - - :param reported_properties_patch: Twin Reported Properties patch as a JSON dict - :type reported_properties_patch: dict - - :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid - and a connection cannot be established. - :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a - connection results in failure. - :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost - during execution. 
- :raises: :class:`azure.iot.device.exceptions.OperationTimeout` if connection attempt - times out - :raises: :class:`azure.iot.device.exceptions.NoConnectionError` if the client is not - connected (and there is no auto-connect enabled) - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. - """ - logger.info("Patching twin reported properties") - - if not self._mqtt_pipeline.feature_enabled[constant.TWIN]: - await self._enable_feature(constant.TWIN) - - patch_twin_async = async_adapter.emulate_async( - self._mqtt_pipeline.patch_twin_reported_properties - ) - - callback = async_adapter.AwaitableCallback() - await patch_twin_async(patch=reported_properties_patch, callback=callback) - await handle_result(callback) - - logger.info("Successfully sent twin patch") - - @deprecation.deprecated( - deprecated_in="2.3.0", - current_version=device_constant.VERSION, - details="We recommend that you use the .on_twin_desired_properties_patch_received property to set a handler instead", - ) - async def receive_twin_desired_properties_patch(self): - """ - Receive a desired property patch via the Azure IoT Hub or Azure IoT Edge Hub. - - If no method request is yet available, will wait until it is available. - - :returns: Twin Desired Properties patch as a JSON dict - :rtype: dict - """ - self._check_receive_mode_is_api() - - if not self._mqtt_pipeline.feature_enabled[constant.TWIN_PATCHES]: - await self._enable_feature(constant.TWIN_PATCHES) - twin_patch_inbox = self._inbox_manager.get_twin_patch_inbox() - - logger.info("Waiting for twin patches...") - patch = await twin_patch_inbox.get() - logger.info("twin patch received") - return patch - - -class IoTHubDeviceClient(GenericIoTHubClient, AbstractIoTHubDeviceClient): - """An asynchronous device client that connects to an Azure IoT Hub instance.""" - - def __init__(self, mqtt_pipeline, http_pipeline): - """Initializer for a IoTHubDeviceClient. 
- - This initializer should not be called directly. - Instead, use one of the 'create_from_' classmethods to instantiate - - :param mqtt_pipeline: The pipeline used to connect to the IoTHub endpoint. - :type mqtt_pipeline: :class:`azure.iot.device.iothub.pipeline.MQTTPipeline` - """ - super().__init__(mqtt_pipeline=mqtt_pipeline, http_pipeline=http_pipeline) - self._mqtt_pipeline.on_c2d_message_received = self._inbox_manager.route_c2d_message - - @deprecation.deprecated( - deprecated_in="2.3.0", - current_version=device_constant.VERSION, - details="We recommend that you use the .on_message_received property to set a handler instead", - ) - async def receive_message(self): - """Receive a message that has been sent from the Azure IoT Hub. - - If no message is yet available, will wait until an item is available. - - :returns: Message that was sent from the Azure IoT Hub. - :rtype: :class:`azure.iot.device.Message` - """ - self._check_receive_mode_is_api() - - if not self._mqtt_pipeline.feature_enabled[constant.C2D_MSG]: - await self._enable_feature(constant.C2D_MSG) - c2d_inbox = self._inbox_manager.get_c2d_message_inbox() - - logger.info("Waiting for message from Hub...") - message = await c2d_inbox.get() - logger.info("Message received") - return message - - async def get_storage_info_for_blob(self, blob_name): - """Sends a POST request over HTTP to an IoTHub endpoint that will return information for uploading via the Azure Storage Account linked to the IoTHub your device is connected to. - - :param str blob_name: The name in string format of the blob that will be uploaded using the storage API. This name will be used to generate the proper credentials for Storage, and needs to match what will be used with the Azure Storage SDK to perform the blob upload. - - :returns: A JSON-like (dictionary) object from IoT Hub that will contain relevant information including: correlationId, hostName, containerName, blobName, sasToken. 
- """ - get_storage_info_for_blob_async = async_adapter.emulate_async( - self._http_pipeline.get_storage_info_for_blob - ) - - callback = async_adapter.AwaitableCallback(return_arg_name="storage_info") - await get_storage_info_for_blob_async(blob_name=blob_name, callback=callback) - storage_info = await handle_result(callback) - logger.info("Successfully retrieved storage_info") - return storage_info - - async def notify_blob_upload_status( - self, correlation_id, is_success, status_code, status_description - ): - """When the upload is complete, the device sends a POST request to the IoT Hub endpoint with information on the status of an upload to blob attempt. This is used by IoT Hub to notify listening clients. - - :param str correlation_id: Provided by IoT Hub on get_storage_info_for_blob request. - :param bool is_success: A boolean that indicates whether the file was uploaded successfully. - :param int status_code: A numeric status code that is the status for the upload of the file to storage. - :param str status_description: A description that corresponds to the status_code. - """ - notify_blob_upload_status_async = async_adapter.emulate_async( - self._http_pipeline.notify_blob_upload_status - ) - - callback = async_adapter.AwaitableCallback() - await notify_blob_upload_status_async( - correlation_id=correlation_id, - is_success=is_success, - status_code=status_code, - status_description=status_description, - callback=callback, - ) - await handle_result(callback) - logger.info("Successfully notified blob upload status") - - -class IoTHubModuleClient(GenericIoTHubClient, AbstractIoTHubModuleClient): - """An asynchronous module client that connects to an Azure IoT Hub or Azure IoT Edge instance.""" - - def __init__(self, mqtt_pipeline, http_pipeline): - """Initializer for a IoTHubModuleClient. - - This initializer should not be called directly. 
- Instead, use one of the 'create_from_' class methods to instantiate - - :param mqtt_pipeline: The pipeline used to connect to the IoTHub endpoint. - :type mqtt_pipeline: :class:`azure.iot.device.iothub.pipeline.MQTTPipeline` - """ - super().__init__(mqtt_pipeline=mqtt_pipeline, http_pipeline=http_pipeline) - self._mqtt_pipeline.on_input_message_received = self._inbox_manager.route_input_message - - async def send_message_to_output(self, message, output_name): - """Sends an event/message to the given module output. - - These are outgoing events and are meant to be "output events" - - If the connection to the service has not previously been opened by a call to connect, this - function will open the connection before sending the event. - - :param message: Message to send to the given output. Anything passed that is not an - instance of the Message class will be converted to Message object. - :type message: :class:`azure.iot.device.Message` or str - :param str output_name: Name of the output to send the event to. - - :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid - and a connection cannot be established. - :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a - connection results in failure. - :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost - during execution. - :raises: :class:`azure.iot.device.exceptions.OperationTimeout` if connection attempt - times out - :raises: :class:`azure.iot.device.exceptions.NoConnectionError` if the client is not - connected (and there is no auto-connect enabled) - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. - :raises: ValueError if the message fails size validation. 
- """ - if not isinstance(message, Message): - message = Message(message) - - if message.get_size() > device_constant.TELEMETRY_MESSAGE_SIZE_LIMIT: - raise ValueError("Size of message can not exceed 256 KB.") - - message.output_name = output_name - - logger.info("Sending message to output:" + output_name + "...") - send_output_message_async = async_adapter.emulate_async( - self._mqtt_pipeline.send_output_message - ) - - callback = async_adapter.AwaitableCallback() - await send_output_message_async(message, callback=callback) - await handle_result(callback) - - logger.info("Successfully sent message to output: " + output_name) - - @deprecation.deprecated( - deprecated_in="2.3.0", - current_version=device_constant.VERSION, - details="We recommend that you use the .on_message_received property to set a handler instead", - ) - async def receive_message_on_input(self, input_name): - """Receive an input message that has been sent from another Module to a specific input. - - If no message is yet available, will wait until an item is available. - - :param str input_name: The input name to receive a message on. - - :returns: Message that was sent to the specified input. - :rtype: :class:`azure.iot.device.Message` - """ - self._check_receive_mode_is_api() - - if not self._mqtt_pipeline.feature_enabled[constant.INPUT_MSG]: - await self._enable_feature(constant.INPUT_MSG) - inbox = self._inbox_manager.get_input_message_inbox(input_name) - - logger.info("Waiting for input message on: " + input_name + "...") - message = await inbox.get() - logger.info("Input message received on: " + input_name) - return message - - async def invoke_method(self, method_params, device_id, module_id=None): - """Invoke a method from your client onto a device or module client, and receive the response to the method call. - - :param dict method_params: Should contain a methodName (str), payload (str), - connectTimeoutInSeconds (int), responseTimeoutInSeconds (int). 
- :param str device_id: Device ID of the target device where the method will be invoked. - :param str module_id: Module ID of the target module where the method will be invoked. (Optional) - - :returns: method_result should contain a status, and a payload - :rtype: dict - """ - logger.info( - "Invoking {} method on {}{}".format(method_params["methodName"], device_id, module_id) - ) - - invoke_method_async = async_adapter.emulate_async(self._http_pipeline.invoke_method) - callback = async_adapter.AwaitableCallback(return_arg_name="invoke_method_response") - await invoke_method_async(device_id, method_params, callback=callback, module_id=module_id) - - method_response = await handle_result(callback) - logger.info("Successfully invoked method") - return method_response diff --git a/azure-iot-device/azure/iot/device/iothub/aio/async_handler_manager.py b/azure-iot-device/azure/iot/device/iothub/aio/async_handler_manager.py deleted file mode 100644 index ec2a51322..000000000 --- a/azure-iot-device/azure/iot/device/iothub/aio/async_handler_manager.py +++ /dev/null @@ -1,246 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" This module contains the manager for handler methods used by the aio clients""" - -import asyncio -import logging -import inspect -import concurrent.futures -from azure.iot.device.common import handle_exceptions -from azure.iot.device.iothub.sync_handler_manager import ( - AbstractHandlerManager, - HandlerManagerException, - HandlerRunnerKillerSentinel, - CLIENT_EVENT, -) -from . 
import loop_management - -logger = logging.getLogger(__name__) - - -class AsyncHandlerManager(AbstractHandlerManager): - """Handler manager for use with asynchronous clients""" - - async def _receiver_handler_runner(self, inbox, handler_name): - """Run infinite loop that waits for an inbox to receive an object from it, then calls - the handler with that object - """ - logger.debug("HANDLER RUNNER ({}): Starting runner".format(handler_name)) - - # Define a callback that can handle errors in the ThreadPoolExecutor - _handler_callback = self._generate_callback_for_handler("CLIENT_EVENT") - - # ThreadPool used for running handler functions. By invoking handlers in a separate thread - # we can be safe knowing that customer code that has performance issues does not block - # client code. Note that the ThreadPool is only used for handler FUNCTIONS (coroutines are - # invoked on a dedicated event loop + thread) - tpe = concurrent.futures.ThreadPoolExecutor(max_workers=4) - while True: - handler_arg = await inbox.get() - if isinstance(handler_arg, HandlerRunnerKillerSentinel): - # Exit the runner when a HandlerRunnerKillerSentinel is found - logger.debug( - "HANDLER RUNNER ({}): HandlerRunnerKillerSentinel found in inbox. Exiting.".format( - handler_name - ) - ) - tpe.shutdown() - break - # NOTE: we MUST use getattr here using the handler name, as opposed to directly passing - # the handler in order for the handler to be able to be updated without cancelling - # the running task created for this coroutine - handler = getattr(self, handler_name) - logger.debug("HANDLER RUNNER ({}): Invoking handler".format(handler_name)) - if inspect.iscoroutinefunction(handler): - # Run coroutine on a dedicated event loop for handler invocations - # TODO: Can we call this on the user loop instead? 
- handler_loop = loop_management.get_client_handler_loop() - fut = asyncio.run_coroutine_threadsafe(handler(handler_arg), handler_loop) - # Free up this object so the garbage collector can free it if necessary. If we don't - # do this, we end up keeping this object alive until the next event arrives, which - # might be a long time. Tests would flag this as a memory leak if that happened. - del handler_arg - fut.add_done_callback(_handler_callback) - else: - # Run function directly in ThreadPool - fut = tpe.submit(handler, handler_arg) - # Free up this object so the garbage collector can free it if necessary. If we don't - # do this, we end up keeping this object alive until the next event arrives, which - # might be a long time. Tests would flag this as a memory leak if that happened. - del handler_arg - fut.add_done_callback(_handler_callback) - - async def _client_event_handler_runner(self): - """Run infinite loop that waits for the client event inbox to receive an event from it, - then calls the handler that corresponds to that event - """ - logger.debug("HANDLER RUNNER (CLIENT EVENT): Starting runner") - _handler_callback = self._generate_callback_for_handler("CLIENT_EVENT") - - # ThreadPool used for running handler functions. By invoking handlers in a separate thread - # we can be safe knowing that customer code that has performance issues does not block - # client code. Note that the ThreadPool is only used for handler FUNCTIONS (coroutines are - # invoked on a dedicated event loop + thread) - tpe = concurrent.futures.ThreadPoolExecutor(max_workers=4) - event_inbox = self._inbox_manager.get_client_event_inbox() - while True: - event = await event_inbox.get() - if isinstance(event, HandlerRunnerKillerSentinel): - # Exit the runner when a HandlerRunnerKillerSentinel is found - logger.debug( - "HANDLER RUNNER (CLIENT EVENT): HandlerRunnerKillerSentinel found in event queue. Exiting." 
- ) - tpe.shutdown() - break - handler = self._get_handler_for_client_event(event.name) - if handler is not None: - logger.debug( - "HANDLER RUNNER (CLIENT EVENT): {} event received. Invoking {} handler".format( - event, handler - ) - ) - if inspect.iscoroutinefunction(handler): - # Run a coroutine on a dedicated event loop for handler invocations - # TODO: Can we call this on the user loop instead? - handler_loop = loop_management.get_client_handler_loop() - fut = asyncio.run_coroutine_threadsafe( - handler(*event.args_for_user), handler_loop - ) - # Free up this object so the garbage collector can free it if necessary. If we don't - # do this, we end up keeping this object alive until the next event arrives, which - # might be a long time. Tests would flag this as a memory leak if that happened. - del event - fut.add_done_callback(_handler_callback) - else: - # Run a function directly in ThreadPool - fut = tpe.submit(handler, *event.args_for_user) - # Free up this object so the garbage collector can free it if necessary. If we don't - # do this, we end up keeping this object alive until the next event arrives, which - # might be a long time. Tests would flag this as a memory leak if that happened. - del event - fut.add_done_callback(_handler_callback) - else: - logger.debug( - "No handler for event {} set. Skipping handler invocation".format(event) - ) - - def _start_handler_runner(self, handler_name): - """Create, and store a task for running a handler""" - # Run the handler runner on a dedicated event loop for handler runners so as to be - # isolated from all other client activities - runner_loop = loop_management.get_client_handler_runner_loop() - - # Client Event handler flow - if handler_name == CLIENT_EVENT: - if self._client_event_runner is not None: - # This branch of code should NOT be reachable due to checks prior to the invocation - # of this method. The branch exists for safety. 
- raise HandlerManagerException( - "Cannot create thread for handler runner: {}. Runner thread already exists".format( - handler_name - ) - ) - # Client events share a handler - coro = self._client_event_handler_runner() - future = asyncio.run_coroutine_threadsafe(coro, runner_loop) - # Store the future - self._client_event_runner = future - - # Receiver handler flow - else: - if self._receiver_handler_runners[handler_name] is not None: - # This branch of code should NOT be reachable due to checks prior to the invocation - # of this method. The branch exists for safety. - raise HandlerManagerException( - "Cannot create task for handler runner: {}. Task already exists".format( - handler_name - ) - ) - # Each receiver handler gets its own runner - inbox = self._get_inbox_for_receive_handler(handler_name) - coro = self._receiver_handler_runner(inbox, handler_name) - future = asyncio.run_coroutine_threadsafe(coro, runner_loop) - # Store the future - self._receiver_handler_runners[handler_name] = future - - _handler_runner_callback = self._generate_callback_for_handler_runner(handler_name) - future.add_done_callback(_handler_runner_callback) - - def _stop_receiver_handler_runner(self, handler_name): - """Stop and remove a handler runner task. - All pending items in the corresponding inbox will be handled by the handler before stoppage. 
- """ - logger.debug( - "Adding HandlerRunnerKillerSentinel to inbox corresponding to {} handler runner".format( - handler_name - ) - ) - inbox = self._get_inbox_for_receive_handler(handler_name) - inbox.put(HandlerRunnerKillerSentinel()) - - # Wait for Handler Runner to end due to the sentinel - logger.debug("Waiting for {} handler runner to exit...".format(handler_name)) - future = self._receiver_handler_runners[handler_name] - future.result() - # Stop tracking the task since it is now complete - self._receiver_handler_runners[handler_name] = None - logger.debug("Handler runner for {} has been stopped".format(handler_name)) - - def _stop_client_event_handler_runner(self): - """Stop and remove a handler task. - All pending items in the client event queue will be handled by handlers (if they exist) - before stoppage. - """ - logger.debug("Adding HandlerRunnerKillerSentinel to client event queue") - event_inbox = self._inbox_manager.get_client_event_inbox() - event_inbox.put(HandlerRunnerKillerSentinel()) - - # Wait for Handler Runner to end due to the stop command - logger.debug("Waiting for client event handler runner to exit...") - future = self._client_event_runner - future.result() - # Stop tracking the task since it is now complete - self._client_event_runner = None - logger.debug("Handler runner for client events has been stopped") - - def _generate_callback_for_handler_runner(self, handler_name): - """Define a callback that can handle errors during handler runner execution""" - - def handler_runner_callback(completed_future): - try: - e = completed_future.exception(timeout=0) - except Exception as raised_e: - # This shouldn't happen because cancellation or timeout shouldn't occur... - # But just in case... 
- new_err = HandlerManagerException( - "HANDLER RUNNER ({}): Unable to retrieve exception data from incomplete task".format( - handler_name - ) - ) - new_err.__cause__ = raised_e - handle_exceptions.handle_background_exception(new_err) - else: - if e: - # If this branch is reached something has gone SERIOUSLY wrong. - # We must log the error, and then restart the runner so that the program - # does not enter an invalid state - new_err = HandlerManagerException( - "HANDLER RUNNER ({}): Unexpected error during task".format(handler_name), - ) - new_err.__cause__ = e - handle_exceptions.handle_background_exception(new_err) - # Clear the tracked runner, and start a new one - logger.debug("HANDLER RUNNER ({}): Restarting handler runner") - self._receiver_handler_runners[handler_name] = None - self._start_handler_runner(handler_name) - else: - logger.debug( - "HANDLER RUNNER ({}): Task successfully completed without exception".format( - handler_name - ) - ) - - return handler_runner_callback diff --git a/azure-iot-device/azure/iot/device/iothub/aio/async_inbox.py b/azure-iot-device/azure/iot/device/iothub/aio/async_inbox.py deleted file mode 100644 index ec09892a0..000000000 --- a/azure-iot-device/azure/iot/device/iothub/aio/async_inbox.py +++ /dev/null @@ -1,82 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module contains an Inbox class for use with an asynchronous client""" -import asyncio -import janus -from azure.iot.device.iothub.sync_inbox import AbstractInbox -from . import loop_management - -# IMPLEMENTATION NOTE: The janus Queue exists entirely on the "client internal loop", -# which runs on its own thread. 
Think of it kind of as a worker loop where async inbox access -# operations are scheduled, with the results returned back to whatever thread/loop scheduled them. -# We do this so that it is safe to use inboxes across different threads, in different places. -# (e.g. customer thread, handler manager thread, callback thread, etc.) - - -class AsyncClientInbox(AbstractInbox): - """Holds generic incoming data for an asynchronous client. - - All methods implemented in this class are threadsafe. - """ - - def __init__(self): - """Initializer for AsyncClientInbox.""" - - # The queue must be instantiated on the client internal loop, but there's no way to do - # that at instantiation from a different loop, so instead we make coroutine to do the - # task and run it on the client internal loop. - # It's not pretty, but it works (newer versions of janus have a loop parameter, but - # not the version we are currently locked at) - async def make_queue(): - return janus.Queue() - - loop = loop_management.get_client_internal_loop() - fut = asyncio.run_coroutine_threadsafe(make_queue(), loop) - self._queue = fut.result() - - def __contains__(self, item): - """Return True if item is in Inbox, False otherwise""" - # Note that this function accesses private attributes of janus, thus it is somewhat - # dangerous. Unfortunately, it is the only way to implement this functionality. - # However, because this function is only used in tests, I feel it is acceptable. - with self._queue._sync_mutex: - return item in self._queue._queue - - def put(self, item): - """Put an item into the Inbox. - - :param item: The item to be put in the Inbox. - """ - self._queue.sync_q.put(item) - - async def get(self): - """Remove and return an item from the Inbox. - - If Inbox is empty, wait until an item is available. - - :returns: An item from the Inbox. 
- """ - loop = loop_management.get_client_internal_loop() - fut = asyncio.run_coroutine_threadsafe(self._queue.async_q.get(), loop) - return await asyncio.wrap_future(fut) - - def empty(self): - """Returns True if the inbox is empty, False otherwise - - Note that there is a race condition here, and this may not be accurate. This is because - the .empty() operation on a janus queue is not threadsafe. - - :returns: Boolean indicating if the inbox is empty - """ - return self._queue.async_q.empty() - - def clear(self): - """Remove all items from the inbox.""" - while True: - try: - self._queue.sync_q.get_nowait() - except janus.SyncQueueEmpty: - break diff --git a/azure-iot-device/azure/iot/device/iothub/aio/loop_management.py b/azure-iot-device/azure/iot/device/iothub/aio/loop_management.py deleted file mode 100644 index f2b73a9c2..000000000 --- a/azure-iot-device/azure/iot/device/iothub/aio/loop_management.py +++ /dev/null @@ -1,67 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" This module contains functions of managing event loops for the IoTHub client -""" -import asyncio -import threading -import logging - -logger = logging.getLogger(__name__) - -loops = { - "CLIENT_HANDLER_LOOP": None, - "CLIENT_INTERNAL_LOOP": None, - "CLIENT_HANDLER_RUNNER_LOOP": None, -} - - -def _cleanup(): - """Clear all running loops and end respective threads. - ONLY FOR TESTING USAGE - By using this function, you can wipe all global loops. 
- DO NOT USE THIS IN PRODUCTION CODE - """ - for loop_name, loop in loops.items(): - if loop is not None: - logger.debug("Stopping event loop - {}".format(loop_name)) - loop.call_soon_threadsafe(loop.stop) - # NOTE: Stopping the loop will also end the thread, because the only thing keeping - # the thread alive was the loop running - loops[loop_name] = None - - -def _make_new_loop(loop_name): - logger.debug("Creating new event loop - {}".format(loop_name)) - # Create the loop on a new Thread - new_loop = asyncio.new_event_loop() - loop_thread = threading.Thread(target=new_loop.run_forever) - # Make the Thread a daemon so it will not block program exit - loop_thread.daemon = True - loop_thread.start() - # Store the loop - loops[loop_name] = new_loop - - -def get_client_internal_loop(): - """Return the loop for internal client operations""" - if loops["CLIENT_INTERNAL_LOOP"] is None: - _make_new_loop("CLIENT_INTERNAL_LOOP") - return loops["CLIENT_INTERNAL_LOOP"] - - -def get_client_handler_runner_loop(): - """Return the loop for handler runners""" - if loops["CLIENT_HANDLER_RUNNER_LOOP"] is None: - _make_new_loop("CLIENT_HANDLER_RUNNER_LOOP") - return loops["CLIENT_HANDLER_RUNNER_LOOP"] - - -def get_client_handler_loop(): - """Return the loop for invoking user-provided handlers on the client""" - # TODO: Try and store the user loop somehow - if loops["CLIENT_HANDLER_LOOP"] is None: - _make_new_loop("CLIENT_HANDLER_LOOP") - return loops["CLIENT_HANDLER_LOOP"] diff --git a/azure-iot-device/azure/iot/device/iothub/client_event.py b/azure-iot-device/azure/iot/device/iothub/client_event.py deleted file mode 100644 index 11f56cd04..000000000 --- a/azure-iot-device/azure/iot/device/iothub/client_event.py +++ /dev/null @@ -1,24 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -NEW_SASTOKEN_REQUIRED = "NEW_SASTOKEN_REQUIRED" -CONNECTION_STATE_CHANGE = "CONNECTION_STATE_CHANGE" -BACKGROUND_EXCEPTION = "BACKGROUND_EXCEPTION" - - -class ClientEvent(object): - """Represents an event that has occurred within the client. - - Contains a name for the event, as well as any associated values that should be provided to the - user when the event occurs. - - Note that a "Client Event" represents an event generated by the client itself, rather than an - unsolicited data receive event (those have a different process) - """ - - def __init__(self, name, *args_for_user): - self.name = name - self.args_for_user = args_for_user diff --git a/azure-iot-device/azure/iot/device/iothub/inbox_manager.py b/azure-iot-device/azure/iot/device/iothub/inbox_manager.py deleted file mode 100644 index 24056e28f..000000000 --- a/azure-iot-device/azure/iot/device/iothub/inbox_manager.py +++ /dev/null @@ -1,190 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module contains a manager for inboxes.""" - -import logging - -logger = logging.getLogger(__name__) - - -class InboxManager(object): - """Manages the various Inboxes for a client. - - :ivar c2d_message_inbox: The C2D message Inbox. - :ivar input_message_inboxes: A dictionary mapping input names to input message Inboxes. - :ivar generic_method_request_inbox: The generic method request Inbox. - :ivar named_method_request_inboxes: A dictionary mapping method names to method request Inboxes. - """ - - def __init__(self, inbox_type): - """Initializer for the InboxManager. - - :param inbox_type: An Inbox class that the manager will use to create Inboxes. 
- """ - self._create_inbox = inbox_type - self.unified_message_inbox = self._create_inbox() - self.generic_method_request_inbox = self._create_inbox() - self.twin_patch_inbox = self._create_inbox() - self.client_event_inbox = self._create_inbox() - - # These inboxes are used only for non-unified receives, using APIs which are now - # deprecated on the client. However we need to keep them functional for backwards - # compatibility - self.c2d_message_inbox = self._create_inbox() - self.input_message_inboxes = {} - self.named_method_request_inboxes = {} - - # Set this value to True if want to only use unified message mode - self.use_unified_msg_mode = False - - def get_unified_message_inbox(self): - """Retrieve the Inbox for all messages (C2D and Input)""" - return self.unified_message_inbox - - def get_input_message_inbox(self, input_name): - """Retrieve the input message Inbox for a given input. - - If the Inbox does not already exist, it will be created. - - :param str input_name: The name of the input for which the associated Inbox is desired. - :returns: An Inbox for input messages on the selected input. - """ - try: - inbox = self.input_message_inboxes[input_name] - except KeyError: - # Create new Inbox for input if it does not yet exist - inbox = self._create_inbox() - self.input_message_inboxes[input_name] = inbox - - return inbox - - def get_c2d_message_inbox(self): - """Retrieve the Inbox for C2D messages. - - :returns: An Inbox for C2D messages. - """ - return self.c2d_message_inbox - - def get_method_request_inbox(self, method_name=None): - """Retrieve the method request Inbox for a given method name if provided, - or for generic method requests if not. - - If the Inbox does not already exist, it will be created. - - :param str method_name: Optional. The name of the method for which the - associated Inbox is desired. - :returns: An Inbox for method requests. 
- """ - if method_name: - try: - inbox = self.named_method_request_inboxes[method_name] - except KeyError: - # Create a new Inbox for the method name - inbox = self._create_inbox() - self.named_method_request_inboxes[method_name] = inbox - else: - inbox = self.generic_method_request_inbox - - return inbox - - def get_twin_patch_inbox(self): - """Retrieve the Inbox for twin patches that arrive from the service - - :returns: An Inbox for twin patches - """ - return self.twin_patch_inbox - - def get_client_event_inbox(self): - """Retrieve the Inbox for events that occur within the client - - :returns: An Inbox for client events - """ - return self.client_event_inbox - - def clear_all_method_requests(self): - """Delete all method requests currently in inboxes.""" - self.generic_method_request_inbox.clear() - for inbox in self.named_method_request_inboxes.values(): - inbox.clear() - - def route_input_message(self, incoming_message): - """Route an incoming input message - - In unified message mode, route to the unified message inbox - - In standard mode, route to the corresponding input message Inbox. If the input - is unknown, the message will be dropped. - - :param incoming_message: The message to be routed. - - :returns: Boolean indicating if message was successfully routed or not. 
- """ - input_name = incoming_message.input_name - if self.use_unified_msg_mode: - # Put in the unified message inbox if in simplified mode - self.unified_message_inbox.put(incoming_message) - return True - else: - # If not in simplified mode, get a specific inbox for the input - try: - inbox = self.input_message_inboxes[input_name] - except KeyError: - logger.warning( - "No input message inbox for {} - dropping message".format(input_name) - ) - return False - else: - inbox.put(incoming_message) - logger.debug("Input message sent to {} inbox".format(input_name)) - return True - - def route_c2d_message(self, incoming_message): - """Route an incoming C2D message - - In unified message mode, route to the unified message inbox. - - In standard mode, route to to the C2D message Inbox. - - :param incoming_message: The message to be routed. - - :returns: Boolean indicating if message was successfully routed or not. - """ - if self.use_unified_msg_mode: - # Put in the unified message inbox if in simplified mode - self.unified_message_inbox.put(incoming_message) - return True - else: - self.c2d_message_inbox.put(incoming_message) - logger.debug("C2D message sent to inbox") - return True - - def route_method_request(self, incoming_method_request): - """Route an incoming method request to the correct method request Inbox. - - If the method name is recognized, it will be routed to a method-specific Inbox. - Otherwise, it will be routed to the generic method request Inbox. - - :param incoming_method_request: The method request to be routed. - - :returns: Boolean indicating if the method request was successfully routed or not. - """ - try: - inbox = self.named_method_request_inboxes[incoming_method_request.name] - except KeyError: - inbox = self.generic_method_request_inbox - inbox.put(incoming_method_request) - return True - - def route_twin_patch(self, incoming_patch): - """Route an incoming twin patch to the twin patch Inbox. - - :param incoming_patch: The patch to be routed. 
- - :returns: Boolean indicating if patch was successfully routed or not. - """ - self.twin_patch_inbox.put(incoming_patch) - logger.debug("twin patch message sent to inbox") - return True diff --git a/azure-iot-device/azure/iot/device/iothub/models/__init__.py b/azure-iot-device/azure/iot/device/iothub/models/__init__.py deleted file mode 100644 index e31fb8930..000000000 --- a/azure-iot-device/azure/iot/device/iothub/models/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Azure IoT Hub Device SDK Models - -This package provides object models for use within the Azure IoT Hub Device SDK. -""" - -from .message import Message # noqa: F401 -from .methods import MethodRequest, MethodResponse # noqa: F401 diff --git a/azure-iot-device/azure/iot/device/iothub/models/message.py b/azure-iot-device/azure/iot/device/iothub/models/message.py deleted file mode 100644 index 599c27502..000000000 --- a/azure-iot-device/azure/iot/device/iothub/models/message.py +++ /dev/null @@ -1,78 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module contains a class representing messages that are sent or received. -""" -from azure.iot.device import constant -import sys - - -class Message(object): - """Represents a message to or from IoTHub - - :ivar data: The data that constitutes the payload - :ivar custom_properties: Dictionary of custom message properties. The keys and values of these properties will always be string. - :ivar message id: A user-settable identifier for the message used for request-reply patterns. 
Format: A case-sensitive string (up to 128 characters long) of ASCII 7-bit alphanumeric characters + {'-', ':', '.', '+', '%', '_', '#', '*', '?', '!', '(', ')', ',', '=', '@', ';', '$', '''} - :ivar expiry_time_utc: Date and time of message expiration in UTC format - :ivar correlation_id: A property in a response message that typically contains the message_id of the request, in request-reply patterns - :ivar user_id: An ID to specify the origin of messages - :ivar content_encoding: Content encoding of the message data. Can be 'utf-8', 'utf-16' or 'utf-32' - :ivar content_type: Content type property used to route messages with the message-body. Can be 'application/json' - :ivar output_name: Name of the output that the message is being sent to. - :ivar input_name: Name of the input that the message was received on. - """ - - def __init__( - self, data, message_id=None, content_encoding=None, content_type=None, output_name=None - ): - """ - Initializer for Message - - :param data: The data that constitutes the payload - :param str message_id: A user-settable identifier for the message used for request-reply patterns. Format: A case-sensitive string (up to 128 characters long) of ASCII 7-bit alphanumeric characters + {'-', ':', '.', '+', '%', '_', '#', '*', '?', '!', '(', ')', ',', '=', '@', ';', '$', '''} - :param str content_encoding: Content encoding of the message data. Other values can be utf-16' or 'utf-32' - :param str content_type: Content type property used to routes with the message body. - :param str output_name: Name of the output that the is being sent to. 
- """ - self.data = data - self.custom_properties = {} - self.message_id = message_id - self.expiry_time_utc = None - self.correlation_id = None - self.user_id = None - self.content_encoding = content_encoding - self.content_type = content_type - self.output_name = output_name - self.input_name = None - self.ack = None - self._iothub_interface_id = None - - @property - def iothub_interface_id(self): - return self._iothub_interface_id - - def set_as_security_message(self): - """ - Set the message as a security message. - - This is a provisional API. Functionality not yet guaranteed. - """ - self._iothub_interface_id = constant.SECURITY_MESSAGE_INTERFACE_ID - - def __str__(self): - return str(self.data) - - def get_size(self): - total = 0 - total = total + sum( - sys.getsizeof(v) - for v in self.__dict__.values() - if v is not None and v is not self.custom_properties - ) - if self.custom_properties: - total = total + sum( - sys.getsizeof(v) for v in self.custom_properties.values() if v is not None - ) - return total diff --git a/azure-iot-device/azure/iot/device/iothub/models/methods.py b/azure-iot-device/azure/iot/device/iothub/models/methods.py deleted file mode 100644 index 16d60afaa..000000000 --- a/azure-iot-device/azure/iot/device/iothub/models/methods.py +++ /dev/null @@ -1,72 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module contains classes related to direct method invocations. -""" - - -class MethodRequest(object): - """Represents a request to invoke a direct method. - - :ivar str request_id: The request id. - :ivar str name: The name of the method to be invoked. - :ivar dict payload: The JSON payload being sent with the request. 
- """ - - def __init__(self, request_id, name, payload): - """Initializer for a MethodRequest. - - :param str request_id: The request id. - :param str name: The name of the method to be invoked - :param dict payload: The JSON payload being sent with the request. - """ - self._request_id = request_id - self._name = name - self._payload = payload - - @property - def request_id(self): - return self._request_id - - @property - def name(self): - return self._name - - @property - def payload(self): - return self._payload - - -class MethodResponse(object): - """Represents a response to a direct method. - - :ivar str request_id: The request id of the MethodRequest being responded to. - :ivar int status: The status of the execution of the MethodRequest. - :ivar payload: The JSON payload to be sent with the response. - :type payload: dict, str, int, float, bool, or None (JSON compatible values) - """ - - def __init__(self, request_id, status, payload=None): - """Initializer for MethodResponse. - - :param str request_id: The request id of the MethodRequest being responded to. - :param int status: The status of the execution of the MethodRequest. - :param payload: The JSON payload to be sent with the response. (OPTIONAL) - :type payload: dict, str, int, float, bool, or None (JSON compatible values) - """ - self.request_id = request_id - self.status = status - self.payload = payload - - @classmethod - def create_from_method_request(cls, method_request, status, payload=None): - """Factory method for creating a MethodResponse from a MethodRequest. - - :param method_request: The MethodRequest object to respond to. - :type method_request: MethodRequest. - :param int status: The status of the execution of the MethodRequest. 
- :type payload: dict, str, int, float, bool, or None (JSON compatible values) - """ - return cls(request_id=method_request.request_id, status=status, payload=payload) diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/__init__.py b/azure-iot-device/azure/iot/device/iothub/pipeline/__init__.py deleted file mode 100644 index ac6a2b54a..000000000 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Azure IoT Hub Device SDK Pipeline - -This package provides a protocol pipeline for use with the Azure IoT Hub Device SDK. - -INTERNAL USAGE ONLY -""" - -from .mqtt_pipeline import MQTTPipeline # noqa: F401 -from .http_pipeline import HTTPPipeline # noqa: F401 -from .config import IoTHubPipelineConfig # noqa: F401 diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/config.py b/azure-iot-device/azure/iot/device/iothub/pipeline/config.py deleted file mode 100644 index ffbee399f..000000000 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/config.py +++ /dev/null @@ -1,51 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -from azure.iot.device.common.pipeline.config import BasePipelineConfig - -logger = logging.getLogger(__name__) - - -class IoTHubPipelineConfig(BasePipelineConfig): - """A class for storing all configurations/options for IoTHub clients in the Azure IoT Python Device Client Library.""" - - def __init__( - self, - hostname, - device_id, - module_id=None, - product_info="", - ensure_desired_properties=True, - **kwargs - ): - """Initializer for IoTHubPipelineConfig which passes all unrecognized keyword-args down to BasePipelineConfig - to be evaluated. 
This stacked options setting is to allow for unique configuration options to exist between the - multiple clients, while maintaining a base configuration class with shared config options. - - :param str hostname: The hostname of the IoTHub to connect to - :param str device_id: The device identity being used with the IoTHub - :param str module_id: The module identity being used with the IoTHub - :param str product_info: A custom identification string for the type of device connecting to Azure IoT Hub. - :param bool ensure_desired_properties: Indicates if twin_patches should ensure the most - recent desired properties patch has been received upon re-connections - """ - super().__init__(hostname=hostname, **kwargs) - - # IoTHub Connection Details - self.device_id = device_id - self.module_id = module_id - - # Product Info - self.product_info = product_info - - # Stage Behavior - self.ensure_desired_properties = ensure_desired_properties - - # Now, the parameters below are not exposed to the user via kwargs. They need to be set by manipulating the IoTHubPipelineConfig object. - # They are not in the BasePipelineConfig because these do not apply to the provisioning client. - self.blob_upload = False - self.method_invoke = False diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/constant.py b/azure-iot-device/azure/iot/device/iothub/pipeline/constant.py deleted file mode 100644 index 8cc88bb06..000000000 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/constant.py +++ /dev/null @@ -1,14 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module contains constants related to the pipeline package. 
-""" - -# Feature names -C2D_MSG = "c2d" -INPUT_MSG = "input" -METHODS = "methods" -TWIN = "twin" -TWIN_PATCHES = "twin_patches" diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/exceptions.py b/azure-iot-device/azure/iot/device/iothub/pipeline/exceptions.py deleted file mode 100644 index a1f10e23f..000000000 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/exceptions.py +++ /dev/null @@ -1,23 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module defines an exception surface, exposed as part of the pipeline API""" - -# For now, present relevant transport errors as part of the Pipeline API surface -# so that they do not have to be duplicated at this layer. -from azure.iot.device.common.pipeline.pipeline_exceptions import * # noqa: F401, F403 -from azure.iot.device.common.transport_exceptions import ( # noqa: F401 - ConnectionFailedError, - ConnectionDroppedError, - NoConnectionError, - # TODO: UnauthorizedError (the one from transport) should probably not surface out of - # the pipeline due to confusion with the higher level service UnauthorizedError. It - # should probably get turned into some other error instead (e.g. ConnectionFailedError). - # But for now, this is a stopgap. 
- UnauthorizedError, - ProtocolClientError, - TlsExchangeAuthError, - ProtocolProxyError, -) diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/http_map_error.py b/azure-iot-device/azure/iot/device/iothub/pipeline/http_map_error.py deleted file mode 100644 index b9188c0d2..000000000 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/http_map_error.py +++ /dev/null @@ -1,67 +0,0 @@ -def translate_error(sc, reason): - """ - Codes_SRS_NODE_IOTHUB_REST_API_CLIENT_16_012: [Any error object returned by translate_error shall inherit from the generic Error Javascript object and have 3 properties: - - response shall contain the IncomingMessage object returned by the HTTP layer. - - responseBody shall contain the content of the HTTP response. - - message shall contain a human-readable error message.] - """ - message = "Error: {}".format(reason) - if sc == 400: - # translate_error shall return an ArgumentError if the HTTP response status code is 400. - error = "ArgumentError({})".format(message) - - elif sc == 401: - # translate_error shall return an UnauthorizedError if the HTTP response status code is 401. - error = "UnauthorizedError({})".format(message) - - elif sc == 403: - # translate_error shall return an TooManyDevicesError if the HTTP response status code is 403. - error = "TooManyDevicesError({})".format(message) - - elif sc == 404: - if reason == "Device Not Found": - # translate_error shall return an DeviceNotFoundError if the HTTP response status code is 404 and if the error code within the body of the error response is DeviceNotFound. - error = "DeviceNotFoundError({})".format(message) - elif reason == "IoTHub Not Found": - # translate_error shall return an IotHubNotFoundError if the HTTP response status code is 404 and if the error code within the body of the error response is IotHubNotFound. 
- error = "IotHubNotFoundError({})".format(message) - else: - error = "Error('Not found')" - - elif sc == 408: - # translate_error shall return a DeviceTimeoutError if the HTTP response status code is 408. - error = "DeviceTimeoutError({})".format(message) - - elif sc == 409: - # translate_error shall return an DeviceAlreadyExistsError if the HTTP response status code is 409. - error = "DeviceAlreadyExistsError({})".format(message) - - elif sc == 412: - # translate_error shall return an InvalidEtagError if the HTTP response status code is 412. - error = "InvalidEtagError({})".format(message) - - elif sc == 429: - # translate_error shall return an ThrottlingError if the HTTP response status code is 429.] - error = "ThrottlingError({})".format(message) - - elif sc == 500: - # translate_error shall return an InternalServerError if the HTTP response status code is 500. - error = "InternalServerError({})".format(message) - - elif sc == 502: - # translate_error shall return a BadDeviceResponseError if the HTTP response status code is 502. - error = "BadDeviceResponseError({})".format(message) - - elif sc == 503: - # translate_error shall return an ServiceUnavailableError if the HTTP response status code is 503. - error = "ServiceUnavailableError({})".format(message) - - elif sc == 504: - # translate_error shall return a GatewayTimeoutError if the HTTP response status code is 504. - error = "GatewayTimeoutError({})".format(message) - - else: - # If the HTTP error code is unknown, translate_error should return a generic Javascript Error object. 
- error = "Error({})".format(message) - - return error diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/http_path_iothub.py b/azure-iot-device/azure/iot/device/iothub/pipeline/http_path_iothub.py deleted file mode 100644 index 7cde14c9d..000000000 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/http_path_iothub.py +++ /dev/null @@ -1,44 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -import urllib - -logger = logging.getLogger(__name__) - - -def get_method_invoke_path(device_id, module_id=None): - """ - :return: The path for invoking methods from one module to a device or module. It is of the format - twins/uri_encode($device_id)/modules/uri_encode($module_id)/methods - """ - if module_id: - return "twins/{device_id}/modules/{module_id}/methods".format( - device_id=urllib.parse.quote_plus(device_id), - module_id=urllib.parse.quote_plus(module_id), - ) - else: - return "twins/{device_id}/methods".format(device_id=urllib.parse.quote_plus(device_id)) - - -def get_storage_info_for_blob_path(device_id): - """ - This does not take a module_id since get_storage_info_for_blob_path should only ever be invoked on device clients. - - :return: The path for getting the storage sdk credential information from IoT Hub. It is of the format - devices/uri_encode($device_id)/files - """ - return "devices/{}/files".format(urllib.parse.quote_plus(device_id)) - - -def get_notify_blob_upload_status_path(device_id): - """ - This does not take a module_id since get_notify_blob_upload_status_path should only ever be invoked on device clients. - - :return: The path for getting the storage sdk credential information from IoT Hub. 
It is of the format - devices/uri_encode($device_id)/files/notifications - """ - return "devices/{}/files/notifications".format(urllib.parse.quote_plus(device_id)) diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/http_pipeline.py b/azure-iot-device/azure/iot/device/iothub/pipeline/http_pipeline.py deleted file mode 100644 index 64257fb60..000000000 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/http_pipeline.py +++ /dev/null @@ -1,174 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -from azure.iot.device.common.evented_callback import EventedCallback -from azure.iot.device.common.pipeline import ( - pipeline_nucleus, - pipeline_stages_base, - pipeline_ops_base, - pipeline_stages_http, -) - -from azure.iot.device.iothub.pipeline import exceptions as pipeline_exceptions - -from . import ( - pipeline_ops_iothub_http, - pipeline_stages_iothub_http, -) - -logger = logging.getLogger(__name__) - - -class HTTPPipeline(object): - """Pipeline to communicate with Edge. - Uses HTTP. - """ - - def __init__(self, pipeline_configuration): - """ - Constructor for instantiating a pipeline adapter object. - - :param auth_provider: The authentication provider - :param pipeline_configuration: The configuration generated based on user inputs - """ - # NOTE: This pipeline DOES NOT handle SasToken management! - # (i.e. using a SasTokenStage) - # It instead relies on the parallel MQTT pipeline to handle that. - # - # Because they share a pipeline configuration, and MQTT has renewal logic we can be sure - # that the SasToken in the pipeline configuration is valid. 
- # - # Furthermore, because HTTP doesn't require constant connections or long running tokens, - # there's no need to reauthorize connections, so we can just pass the token from the config - # when needed for auth. - # - # This is not an ideal solution, but it's the simplest one for the time being. - - # Contains data and information shared globally within the pipeline - self._nucleus = pipeline_nucleus.PipelineNucleus(pipeline_configuration) - - self._pipeline = ( - pipeline_stages_base.PipelineRootStage(self._nucleus) - .append_stage(pipeline_stages_iothub_http.IoTHubHTTPTranslationStage()) - .append_stage(pipeline_stages_http.HTTPTransportStage()) - ) - - callback = EventedCallback() - - op = pipeline_ops_base.InitializePipelineOperation(callback=callback) - - self._pipeline.run_op(op) - callback.wait_for_completion() - - def invoke_method(self, device_id, method_params, callback, module_id=None): - """ - Send a request to the service to invoke a method on a target device or module. - - :param device_id: The target device id - :param method_params: The method parameters to be invoked on the target client - :param callback: callback which is called when request has been fulfilled. - On success, this callback is called with the error=None. - On failure, this callback is called with error set to the cause of the failure. - :param module_id: The target module id - - The following exceptions are not "raised", but rather returned via the "error" parameter - when invoking "callback": - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` - """ - logger.debug("HTTPPipeline invoke_method called") - if not self._nucleus.pipeline_configuration.method_invoke: - # If this parameter is not set, that means that the pipeline was not generated by the edge environment. Method invoke only works for clients generated using the edge environment. 
- error = pipeline_exceptions.PipelineRuntimeError( - "invoke_method called, but it is only supported on module clients generated from an edge environment. If you are not using a module generated from an edge environment, you cannot use invoke_method" - ) - return callback(error=error) - - def on_complete(op, error): - callback(error=error, invoke_method_response=op.method_response) - - self._pipeline.run_op( - pipeline_ops_iothub_http.MethodInvokeOperation( - target_device_id=device_id, - target_module_id=module_id, - method_params=method_params, - callback=on_complete, - ) - ) - - def get_storage_info_for_blob(self, blob_name, callback): - """ - Sends a POST request to the IoT Hub service endpoint to retrieve an object that contains information for uploading via the Storage SDK. - - :param blob_name: The name of the blob that will be uploaded via the Azure Storage SDK. - :param callback: callback which is called when request has been fulfilled. - On success, this callback is called with the error=None, and the storage_info set to the information JSON received from the service. - On failure, this callback is called with error set to the cause of the failure, and the storage_info=None. - - The following exceptions are not "raised", but rather returned via the "error" parameter - when invoking "callback": - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` - """ - logger.debug("HTTPPipeline get_storage_info_for_blob called") - if not self._nucleus.pipeline_configuration.blob_upload: - # If this parameter is not set, that means this is not a device client. Upload to blob is not supported on module clients. - error = pipeline_exceptions.PipelineRuntimeError( - "get_storage_info_for_blob called, but it is only supported for use with device clients. Ensure you are using a device client." 
- ) - return callback(error=error) - - def on_complete(op, error): - callback(error=error, storage_info=op.storage_info) - - self._pipeline.run_op( - pipeline_ops_iothub_http.GetStorageInfoOperation( - blob_name=blob_name, callback=on_complete - ) - ) - - def notify_blob_upload_status( - self, correlation_id, is_success, status_code, status_description, callback - ): - """ - Sends a POST request to a IoT Hub service endpoint to notify the status of the Storage SDK call for a blob upload. - - :param str correlation_id: Provided by IoT Hub on get_storage_info_for_blob request. - :param bool is_success: A boolean that indicates whether the file was uploaded successfully. - :param int status_code: A numeric status code that is the status for the upload of the file to storage. - :param str status_description: A description that corresponds to the status_code. - - :param callback: callback which is called when request has been fulfilled. - On success, this callback is called with the error=None. - On failure, this callback is called with error set to the cause of the failure. - - - The following exceptions are not "raised", but rather returned via the "error" parameter - when invoking "callback": - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` - """ - logger.debug("HTTPPipeline notify_blob_upload_status called") - if not self._nucleus.pipeline_configuration.blob_upload: - # If this parameter is not set, that means this is not a device client. Upload to blob is not supported on module clients. - error = pipeline_exceptions.PipelineRuntimeError( - "notify_blob_upload_status called, but it is only supported for use with device clients. Ensure you are using a device client." 
- ) - return callback(error=error) - - def on_complete(op, error): - callback(error=error) - - self._pipeline.run_op( - pipeline_ops_iothub_http.NotifyBlobUploadStatusOperation( - correlation_id=correlation_id, - is_success=is_success, - status_code=status_code, - status_description=status_description, - callback=on_complete, - ) - ) diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/mqtt_pipeline.py b/azure-iot-device/azure/iot/device/iothub/pipeline/mqtt_pipeline.py deleted file mode 100644 index fdea45b99..000000000 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/mqtt_pipeline.py +++ /dev/null @@ -1,596 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -from azure.iot.device.common.evented_callback import EventedCallback -from azure.iot.device.common.pipeline import ( - pipeline_nucleus, - pipeline_stages_base, - pipeline_ops_base, - pipeline_stages_mqtt, - pipeline_exceptions, -) -from . 
import ( - constant, - pipeline_stages_iothub, - pipeline_events_iothub, - pipeline_ops_iothub, - pipeline_stages_iothub_mqtt, -) - -logger = logging.getLogger(__name__) - - -class MQTTPipeline(object): - def __init__(self, pipeline_configuration): - """ - Constructor for instantiating a pipeline adapter object - :param auth_provider: The authentication provider - :param pipeline_configuration: The configuration generated based on user inputs - """ - - self.feature_enabled = { - constant.C2D_MSG: False, - constant.INPUT_MSG: False, - constant.METHODS: False, - constant.TWIN: False, - constant.TWIN_PATCHES: False, - } - - # Handlers - Will be set by Client after instantiation of this object - self.on_connected = None - self.on_disconnected = None - self.on_new_sastoken_required = None - self.on_background_exception = None - - self.on_c2d_message_received = None - self.on_input_message_received = None - self.on_method_request_received = None - self.on_twin_patch_received = None - - # Contains data and information shared globally within the pipeline - self._nucleus = pipeline_nucleus.PipelineNucleus(pipeline_configuration) - - self._pipeline = ( - # - # The root is always the root. By definition, it's the first stage in the pipeline. - # - pipeline_stages_base.PipelineRootStage(self._nucleus) - # - # SasTokenStage comes near the root by default because it should be as close - # to the top of the pipeline as possible, and does not need to be after anything. - # - .append_stage(pipeline_stages_base.SasTokenStage()) - # - # EnsureDesiredPropertiesStage needs to be above TwinRequestResponseStage because it - # sends GetTwinOperation ops and that stage handles those ops. 
- # - .append_stage(pipeline_stages_iothub.EnsureDesiredPropertiesStage()) - # - # TwinRequestResponseStage comes near the root by default because it doesn't need to be - # after anything - # - .append_stage(pipeline_stages_iothub.TwinRequestResponseStage()) - # - # CoordinateRequestAndResponseStage needs to be after TwinRequestResponseStage because - # TwinRequestResponseStage creates the request ops that CoordinateRequestAndResponseStage - # is coordinating. It needs to be before IoTHubMQTTTranslationStage because that stage - # operates on ops that CoordinateRequestAndResponseStage produces - # - .append_stage(pipeline_stages_base.CoordinateRequestAndResponseStage()) - # - # IoTHubMQTTTranslationStage comes here because this is the point where we can translate - # all operations directly into MQTT. After this stage, only pipeline_stages_base stages - # are allowed because IoTHubMQTTTranslationStage removes all the IoTHub-ness from the ops - # - .append_stage(pipeline_stages_iothub_mqtt.IoTHubMQTTTranslationStage()) - # - # AutoConnectStage comes here because only MQTT ops have the need_connection flag set - # and this is the first place in the pipeline where we can guarantee that all network - # ops are MQTT ops. - # - .append_stage(pipeline_stages_base.AutoConnectStage()) - # - # ConnectionStateStage needs to be after AutoConnectStage because the AutoConnectStage - # can create ConnectOperations and we (may) want to queue connection related operations - # in the ConnectionStateStage - # - .append_stage(pipeline_stages_base.ConnectionStateStage()) - # - # RetryStage needs to be near the end because it's retrying low-level MQTT operations. - # - .append_stage(pipeline_stages_base.RetryStage()) - # - # OpTimeoutStage needs to be after RetryStage because OpTimeoutStage returns the timeout - # errors that RetryStage is watching for. 
- # - .append_stage(pipeline_stages_base.OpTimeoutStage()) - # - # MQTTTransportStage needs to be at the very end of the pipeline because this is where - # operations turn into network traffic - # - .append_stage(pipeline_stages_mqtt.MQTTTransportStage()) - ) - - # Define behavior for domain-specific events - def _on_pipeline_event(event): - if isinstance(event, pipeline_events_iothub.C2DMessageEvent): - if self.on_c2d_message_received: - self.on_c2d_message_received(event.message) - else: - logger.debug("C2D message event received with no handler. dropping.") - - elif isinstance(event, pipeline_events_iothub.InputMessageEvent): - if self.on_input_message_received: - self.on_input_message_received(event.message) - else: - logger.debug("input message event received with no handler. dropping.") - - elif isinstance(event, pipeline_events_iothub.MethodRequestEvent): - if self.on_method_request_received: - self.on_method_request_received(event.method_request) - else: - logger.debug("Method request event received with no handler. Dropping.") - - elif isinstance(event, pipeline_events_iothub.TwinDesiredPropertiesPatchEvent): - if self.on_twin_patch_received: - self.on_twin_patch_received(event.patch) - else: - logger.debug("Twin patch event received with no handler. 
Dropping.") - - else: - logger.debug("Dropping unknown pipeline event {}".format(event.name)) - - def _on_connected(): - if self.on_connected: - self.on_connected() - else: - logger.debug("IoTHub Pipeline was connected, but no handler was set") - - def _on_disconnected(): - if self.on_disconnected: - self.on_disconnected() - else: - logger.debug("IoTHub Pipeline was disconnected, but no handler was set") - - def _on_new_sastoken_required(): - if self.on_new_sastoken_required: - self.on_new_sastoken_required() - else: - logger.debug("IoTHub Pipeline requires new SASToken, but no handler was set") - - def _on_background_exception(e): - if self.on_background_exception: - self.on_background_exception(e) - else: - logger.debug( - "IoTHub Pipeline experienced background exception, but no handler was set" - ) - - # Set internal event handlers - self._pipeline.on_pipeline_event_handler = _on_pipeline_event - self._pipeline.on_connected_handler = _on_connected - self._pipeline.on_disconnected_handler = _on_disconnected - self._pipeline.on_new_sastoken_required_handler = _on_new_sastoken_required - self._pipeline.on_background_exception_handler = _on_background_exception - - # Initialize the pipeline - callback = EventedCallback() - op = pipeline_ops_base.InitializePipelineOperation(callback=callback) - self._pipeline.run_op(op) - callback.wait_for_completion() - - # Set the running flag - self._running = True - - def _verify_running(self): - if not self._running: - raise pipeline_exceptions.PipelineNotRunning( - "Cannot execute method - Pipeline is not running" - ) - - def shutdown(self, callback): - """Shut down the pipeline and clean up any resources. - - Once shut down, making any further calls on the pipeline will result in a - PipelineNotRunning exception being raised. - - There is currently no way to resume pipeline functionality once shutdown has occurred. 
- - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.PipelineNotRunning` if the - pipeline has already been shut down - - The shutdown process itself is not expected to fail under any normal condition, but if it - does, exceptions are not "raised", but rather, returned via the "error" parameter when - invoking "callback". - """ - self._verify_running() - logger.debug("Commencing shutdown of pipeline") - - def on_complete(op, error): - if not error: - # Only set the pipeline to not be running if the op was successful - self._running = False - callback(error=error) - - # NOTE: While we do run this operation, its functionality is incomplete. Some stages still - # need a response to this operation implemented. Additionally, there are other pipeline - # constructs other than Stages (e.g. Operations) which may have timers attached. These are - # lesser issues, but should be addressed at some point. - # TODO: Truly complete the shutdown implementation - self._pipeline.run_op(pipeline_ops_base.ShutdownPipelineOperation(callback=on_complete)) - - def connect(self, callback): - """ - Connect to the service. - - :param callback: callback which is called when the connection attempt is complete. 
- - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.PipelineNotRunning` if the - pipeline has previously been shut down - - The following exceptions are not "raised", but rather returned via the "error" parameter - when invoking "callback": - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionFailedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionDroppedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.UnauthorizedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.OperationTimeout` - """ - self._verify_running() - logger.debug("Starting ConnectOperation on the pipeline") - - def on_complete(op, error): - callback(error=error) - - self._pipeline.run_op(pipeline_ops_base.ConnectOperation(callback=on_complete)) - - def disconnect(self, callback): - """ - Disconnect from the service. - - Note that even if this fails for some reason, the client will be in a disconnected state. - - :param callback: callback which is called when the disconnection is complete. - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.PipelineNotRunning` if the - pipeline has previously been shut down - - The following exceptions are not "raised", but rather returned via the "error" parameter - when invoking "callback": - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` - """ - self._verify_running() - logger.debug("Starting DisconnectOperation on the pipeline") - - def on_complete(op, error): - callback(error=error) - - self._pipeline.run_op(pipeline_ops_base.DisconnectOperation(callback=on_complete)) - - def reauthorize_connection(self, callback): - """ - Reauthorize connection to the service by disconnecting and then reconnecting using - fresh credentials. - - This can be called regardless of connection state. If successful, the client will be - connected. 
If unsuccessful, the client will be disconnected. - - :param callback: callback which is called when the reauthorization attempt is complete. - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.PipelineNotRunning` if the - pipeline has previously been shut down - - The following exceptions are not "raised", but rather returned via the "error" parameter - when invoking "callback": - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionFailedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionDroppedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.UnauthorizedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.OperationTimeout` - """ - self._verify_running() - logger.debug("Starting ReauthorizeConnectionOperation on the pipeline") - - def on_complete(op, error): - callback(error=error) - - self._pipeline.run_op( - pipeline_ops_base.ReauthorizeConnectionOperation(callback=on_complete) - ) - - def send_message(self, message, callback): - """ - Send a telemetry message to the service. - - :param message: message to send. - :param callback: callback which is called when the message publish attempt is complete. 
- - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.PipelineNotRunning` if the - pipeline has previously been shut down - - The following exceptions are not "raised", but rather returned via the "error" parameter - when invoking "callback": - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.NoConnectionError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` - - The following exceptions can be returned via the "error" parameter only if auto-connect - is enabled in the pipeline configuration: - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionFailedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionDroppedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.UnauthorizedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.OperationTimeout` - """ - self._verify_running() - logger.debug("Starting SendD2CMessageOperation on the pipeline") - - def on_complete(op, error): - callback(error=error) - - self._pipeline.run_op( - pipeline_ops_iothub.SendD2CMessageOperation(message=message, callback=on_complete) - ) - - def send_output_message(self, message, callback): - """ - Send an output message to the service. - - :param message: message to send. - :param callback: callback which is called when the message publish attempt is complete. 
- - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.PipelineNotRunning` if the - pipeline has previously been shut down - - The following exceptions are not "raised", but rather returned via the "error" parameter - when invoking "callback": - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.NoConnectionError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` - - The following exceptions can be returned via the "error" parameter only if auto-connect - is enabled in the pipeline configuration: - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionFailedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionDroppedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.UnauthorizedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.OperationTimeout` - """ - self._verify_running() - logger.debug("Starting SendOutputMessageOperation on the pipeline") - - def on_complete(op, error): - callback(error=error) - - self._pipeline.run_op( - pipeline_ops_iothub.SendOutputMessageOperation(message=message, callback=on_complete) - ) - - def send_method_response(self, method_response, callback): - """ - Send a method response to the service. - - :param method_response: the method response to send - :param callback: callback which is called when response publish attempt is complete. 
- - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.PipelineNotRunning` if the - pipeline has previously been shut down - - The following exceptions are not "raised", but rather returned via the "error" parameter - when invoking "callback": - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.NoConnectionError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` - - The following exceptions can be returned via the "error" parameter only if auto-connect - is enabled in the pipeline configuration: - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionFailedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionDroppedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.UnauthorizedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.OperationTimeout` - """ - self._verify_running() - logger.debug("Starting SendMethodResponseOperation on the pipeline") - - def on_complete(op, error): - callback(error=error) - - self._pipeline.run_op( - pipeline_ops_iothub.SendMethodResponseOperation( - method_response=method_response, callback=on_complete - ) - ) - - def get_twin(self, callback): - """ - Send a request for a full twin to the service. - - :param callback: callback which is called when request attempt is complete. - This callback should have two parameters. On success, this callback is called with the - requested twin and error=None. On failure, this callback is called with None for the - requested win and error set to the cause of the failure. 
- - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.PipelineNotRunning` if the - pipeline has previously been shut down - - The following exceptions are not "raised", but rather returned via the "error" parameter - when invoking "callback": - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.NoConnectionError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` - - The following exceptions can be returned via the "error" parameter only if auto-connect - is enabled in the pipeline configuration: - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionFailedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionDroppedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.UnauthorizedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.OperationTimeout` - """ - self._verify_running() - logger.debug("Starting GetTwinOperation on the pipeline") - - def on_complete(op, error): - if error: - callback(error=error, twin=None) - else: - callback(twin=op.twin) - - self._pipeline.run_op(pipeline_ops_iothub.GetTwinOperation(callback=on_complete)) - - def patch_twin_reported_properties(self, patch, callback): - """ - Send a patch for a twin's reported properties to the service. - - :param patch: the reported properties patch to send - :param callback: callback which is called when the request attempt is complete. 
- - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.PipelineNotRunning` if the - pipeline has previously been shut down - - The following exceptions are not "raised", but rather returned via the "error" parameter - when invoking "callback": - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.NoConnectionError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` - - The following exceptions can be returned via the "error" parameter only if auto-connect - is enabled in the pipeline configuration: - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionFailedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionDroppedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.UnauthorizedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.OperationTimeout` - """ - self._verify_running() - logger.debug("Starting PatchTwinReportedPropertiesOperation on the pipeline") - - def on_complete(op, error): - callback(error=error) - - self._pipeline.run_op( - pipeline_ops_iothub.PatchTwinReportedPropertiesOperation( - patch=patch, callback=on_complete - ) - ) - - # NOTE: Currently, this operation will retry itself indefinitely in the case of timeout - def enable_feature(self, feature_name, callback): - """ - Enable the given feature by subscribing to the appropriate topics. 
- - :param feature_name: one of the feature name constants from constant.py - :param callback: callback which is called when the feature is enabled - - :raises: ValueError if feature_name is invalid - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.PipelineNotRunning` if the - pipeline has previously been shut down - - The following exceptions are not "raised", but rather returned via the "error" parameter - when invoking "callback": - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.NoConnectionError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` - - The following exceptions can be returned via the "error" parameter only if auto-connect - is enabled in the pipeline configuration: - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionFailedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionDroppedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.UnauthorizedError` - """ - self._verify_running() - logger.debug("enable_feature {} called".format(feature_name)) - if feature_name not in self.feature_enabled: - raise ValueError("Invalid feature_name") - # TODO: What about if the feature is already enabled? - - def on_complete(op, error): - if error: - logger.warning( - "Subscribe for {} failed. Not enabling feature".format(feature_name) - ) - else: - self.feature_enabled[feature_name] = True - callback(error=error) - - self._pipeline.run_op( - pipeline_ops_base.EnableFeatureOperation( - feature_name=feature_name, callback=on_complete - ) - ) - - # NOTE: Currently, this operation will retry itself indefinitely in the case of timeout - def disable_feature(self, feature_name, callback): - """ - Disable the given feature by subscribing to the appropriate topics. 
- :param callback: callback which is called when the feature is disabled - - :param feature_name: one of the feature name constants from constant.py - - :raises: ValueError if feature_name is invalid - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.PipelineNotRunning` if the - pipeline has previously been shut down - - The following exceptions are not "raised", but rather returned via the "error" parameter - when invoking "callback": - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.NoConnectionError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` - - The following exceptions can be returned via the "error" parameter only if auto-connect - is enabled in the pipeline configuration: - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionFailedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionDroppedError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.UnauthorizedError` - """ - self._verify_running() - logger.debug("disable_feature {} called".format(feature_name)) - if feature_name not in self.feature_enabled: - raise ValueError("Invalid feature_name") - # TODO: What about if the feature is already disabled? - - def on_complete(op, error): - if error: - logger.warning( - "Error occurred while disabling feature. Unclear if subscription for {} is still alive or not".format( - feature_name - ) - ) - - # No matter what, mark the feature as disabled, even if there was an error. - # This is safer than only marking it disabled upon operation success, because an op - # could fail after successfully doing the network operations to change the subscription - # state, and then we would be stuck in a bad state. 
- self.feature_enabled[feature_name] = False - callback(error=error) - - self._pipeline.run_op( - pipeline_ops_base.DisableFeatureOperation( - feature_name=feature_name, callback=on_complete - ) - ) - - @property - def pipeline_configuration(self): - """ - Pipeline Configuration for the pipeline. Note that while a new config object cannot be - provided (read-only), the values stored in the config object CAN be changed. - """ - return self._nucleus.pipeline_configuration - - @property - def connected(self): - """ - Read-only property to indicate if the transport is connected or not. - """ - return self._nucleus.connected diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/mqtt_topic_iothub.py b/azure-iot-device/azure/iot/device/iothub/pipeline/mqtt_topic_iothub.py deleted file mode 100644 index 9db638882..000000000 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/mqtt_topic_iothub.py +++ /dev/null @@ -1,420 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -from datetime import date -import urllib - -logger = logging.getLogger(__name__) - -# NOTE: Whenever using standard URL encoding via the urllib.parse.quote() API -# make sure to specify that there are NO safe values (e.g. safe=""). By default -# "/" is skipped in encoding, and that is not desirable. -# -# DO NOT use urllib.parse.quote_plus(), as it turns ' ' characters into '+', -# which is invalid for MQTT publishes. -# -# DO NOT use urllib.parse.unquote_plus(), as it turns '+' characters into ' ', -# which is also invalid. - - -# NOTE (Oct 2020): URL encoding policy is currently inconsistent in this module due to restrictions -# with the Hub, as Hub does not do URL decoding on most values. 
-# (see: https://github.com/Azure/azure-iot-sdk-python/wiki/URL-Encoding-(MQTT)). -# Currently, as much as possible is URL encoded while keeping in line with the policy outlined -# in the above linked wiki article. This is to say that Device ID and Module ID are never -# encoded, however other values are. By convention, it's probably fine to be encoding/decoding most -# values that are not Device ID or Module ID, since it won't make a difference in production as -# the narrow range of acceptable values for, say, status code, or request ID don't contain any -# characters that require URL encoding/decoding in the first place. Thus it doesn't break on Hub, -# but it's still done here as a client-side best practice - Hub will eventually be doing a new API -# that does correctly URL encode/decode all values, so it's not good to roll back more than -# is currently necessary to avoid errors. - - -def _get_topic_base(device_id, module_id=None): - """ - return the string that is at the beginning of all topics for this - device/module - """ - - # NOTE: Neither Device ID nor Module ID should be URL encoded in a topic string. - # See the repo wiki article for details: - # https://github.com/Azure/azure-iot-sdk-python/wiki/URL-Encoding-(MQTT) - topic = "devices/" + str(device_id) - if module_id: - topic = topic + "/modules/" + str(module_id) - return topic - - -def get_c2d_topic_for_subscribe(device_id): - """ - :return: The topic for cloud to device messages.It is of the format - "devices//messages/devicebound/#" - """ - return _get_topic_base(device_id) + "/messages/devicebound/#" - - -def get_input_topic_for_subscribe(device_id, module_id): - """ - :return: The topic for input messages. It is of the format - "devices//modules//inputs/#" - """ - return _get_topic_base(device_id, module_id) + "/inputs/#" - - -def get_method_topic_for_subscribe(): - """ - :return: The topic for ALL incoming methods. 
It is of the format - "$iothub/methods/POST/#" - """ - return "$iothub/methods/POST/#" - - -def get_twin_response_topic_for_subscribe(): - """ - :return: The topic for ALL incoming twin responses. It is of the format - "$iothub/twin/res/#" - """ - return "$iothub/twin/res/#" - - -def get_twin_patch_topic_for_subscribe(): - """ - :return: The topic for ALL incoming twin patches. It is of the format - "$iothub/twin/PATCH/properties/desired/# - """ - return "$iothub/twin/PATCH/properties/desired/#" - - -def get_telemetry_topic_for_publish(device_id, module_id): - """ - return the topic string used to publish telemetry - """ - return _get_topic_base(device_id, module_id) + "/messages/events/" - - -def get_method_topic_for_publish(request_id, status): - """ - :return: The topic for publishing method responses. It is of the format - "$iothub/methods/res//?$rid= - """ - return "$iothub/methods/res/{status}/?$rid={request_id}".format( - status=urllib.parse.quote(str(status), safe=""), - request_id=urllib.parse.quote(str(request_id), safe=""), - ) - - -# NOTE: Consider splitting this into separate logic for Twin Requests / Twin Patches -# This is the only method that is shared. Would probably simplify code if it was split. -# Please consider refactoring. -def get_twin_topic_for_publish(method, resource_location, request_id): - """ - :return: The topic for publishing twin requests / patches. 
It is of the format - "$iothub/twin/?$rid= - """ - return "$iothub/twin/{method}{resource_location}?$rid={request_id}".format( - method=method, - resource_location=resource_location, - request_id=urllib.parse.quote(str(request_id), safe=""), - ) - - -def is_c2d_topic(topic, device_id): - """ - Topics for c2d message are of the following format: - devices//messages/devicebound - :param topic: The topic string - """ - # Device ID is not URL encoded in a topic string - # See the repo wiki article for details: - # https://github.com/Azure/azure-iot-sdk-python/wiki/URL-Encoding-(MQTT) - if "devices/{}/messages/devicebound".format(device_id) in topic: - return True - return False - - -def is_input_topic(topic, device_id, module_id): - """ - Topics for inputs are of the following format: - devices//modules//inputs/ - :param topic: The topic string - """ - if not device_id or not module_id: - return False - # NOTE: Neither Device ID nor Module ID are URL encoded in a topic string. - # See the repo wiki article for details: - # https://github.com/Azure/azure-iot-sdk-python/wiki/URL-Encoding-(MQTT) - if "devices/{}/modules/{}/inputs/".format(device_id, module_id) in topic: - return True - return False - - -def is_method_topic(topic): - """ - Topics for methods are of the following format: - "$iothub/methods/POST/{method name}/?$rid={request id}" - - :param str topic: The topic string. 
- """ - if "$iothub/methods/POST" in topic: - return True - return False - - -def is_twin_response_topic(topic): - """Topics for twin responses are of the following format: - $iothub/twin/res/{status}/?$rid={rid} - - :param str topic: The topic string - """ - return topic.startswith("$iothub/twin/res/") - - -def is_twin_desired_property_patch_topic(topic): - """Topics for twin desired property patches are of the following format: - $iothub/twin/PATCH/properties/desired - - :param str topic: The topic string - """ - return topic.startswith("$iothub/twin/PATCH/properties/desired") - - -def get_input_name_from_topic(topic): - """ - Extract the input channel from the topic name - Topics for inputs are of the following format: - devices//modules//inputs/ - - :param topic: The topic string - """ - parts = topic.split("/") - if len(parts) > 5 and parts[4] == "inputs": - return urllib.parse.unquote(parts[5]) - else: - raise ValueError("topic has incorrect format") - - -def get_method_name_from_topic(topic): - """ - Extract the method name from the method topic. - Topics for methods are of the following format: - "$iothub/methods/POST/{method name}/?$rid={request id}" - - :param str topic: The topic string - """ - parts = topic.split("/") - if is_method_topic(topic) and len(parts) >= 4: - return urllib.parse.unquote(parts[3]) - else: - raise ValueError("topic has incorrect format") - - -def get_method_request_id_from_topic(topic): - """ - Extract the Request ID (RID) from the method topic. 
- Topics for methods are of the following format: - "$iothub/methods/POST/{method name}/?$rid={request id}" - - :param str topic: the topic string - :raises: ValueError if topic has incorrect format - :returns: request id from topic string - """ - parts = topic.split("/") - if is_method_topic(topic) and len(parts) >= 4: - - properties = _extract_properties(topic.split("?")[1]) - return properties["rid"] - else: - raise ValueError("topic has incorrect format") - - -def get_twin_request_id_from_topic(topic): - """ - Extract the Request ID (RID) from the twin response topic. - Topics for twin response are in the following format: - "$iothub/twin/res/{status}/?$rid={rid}" - - :param str topic: The topic string - :raises: ValueError if topic has incorrect format - :returns: request id from topic string - """ - parts = topic.split("/") - if is_twin_response_topic(topic) and len(parts) >= 4: - properties = _extract_properties(topic.split("?")[1]) - return properties["rid"] - else: - raise ValueError("topic has incorrect format") - - -def get_twin_status_code_from_topic(topic): - """ - Extract the status code from the twin response topic. - Topics for twin response are in the following format: - "$iothub/twin/res/{status}/?$rid={rid}" - - :param str topic: The topic string - :raises: ValueError if the topic has incorrect format - :returns status code from topic string - """ - parts = topic.split("/") - if is_twin_response_topic(topic) and len(parts) >= 4: - return urllib.parse.unquote(parts[3]) - else: - raise ValueError("topic has incorrect format") - - -def extract_message_properties_from_topic(topic, message_received): - """ - Extract key=value pairs from custom properties and set the properties on the received message. 
- For extracting values corresponding to keys the following rules are followed:- - If there is NO "=", the value is None - If there is "=" with no value, the value is an empty string - For anything else the value after "=" and before `&` is considered as the proper value - :param topic: The topic string - :param message_received: The message received with the payload in bytes - """ - - parts = topic.split("/") - # Input Message Topic - if len(parts) > 4 and parts[4] == "inputs": - if len(parts) > 6: - properties = parts[6] - else: - properties = None - # C2D Message Topic - elif len(parts) > 3 and parts[3] == "devicebound": - if len(parts) > 4: - properties = parts[4] - else: - properties = None - else: - raise ValueError("topic has incorrect format") - - # We do not want to extract values corresponding to these keys - ignored_extraction_values = ["$.to"] - - # NOTE: we cannot use urllib.parse.parse_qs because it always decodes '+' as ' ', - # and the behavior cannot be overridden. Must parse key/value pairs manually. 
- - if properties: - key_value_pairs = properties.split("&") - - for entry in key_value_pairs: - pair = entry.split("=") - key = urllib.parse.unquote(pair[0]) - if len(pair) > 1: - value = urllib.parse.unquote(pair[1]) - else: # Don't skip the key - value = None - - if key in ignored_extraction_values: - continue - elif key == "$.mid": - message_received.message_id = value - elif key == "$.cid": - message_received.correlation_id = value - elif key == "$.uid": - message_received.user_id = value - elif key == "$.ct": - message_received.content_type = value - elif key == "$.ce": - message_received.content_encoding = value - elif key == "$.exp": - message_received.expiry_time_utc = value - elif key == "iothub-ack": - message_received.ack = value - else: - message_received.custom_properties[key] = value - - -def encode_message_properties_in_topic(message_to_send, topic): - """ - uri-encode the system properties of a message as key-value pairs on the topic with defined keys. - Additionally if the message has user defined properties, the property keys and values shall be - uri-encoded and appended at the end of the above topic with the following convention: - '=&=&=(...)' - :param message_to_send: The message to send - :param topic: The topic which has not been encoded yet. 
For a device it looks like - "devices//messages/events/" and for a module it looks like - "devices//modules//messages/events/ - :return: The topic which has been uri-encoded - """ - system_properties = [] - if message_to_send.output_name: - system_properties.append(("$.on", str(message_to_send.output_name))) - if message_to_send.message_id: - system_properties.append(("$.mid", str(message_to_send.message_id))) - - if message_to_send.correlation_id: - system_properties.append(("$.cid", str(message_to_send.correlation_id))) - - if message_to_send.user_id: - system_properties.append(("$.uid", str(message_to_send.user_id))) - - if message_to_send.content_type: - system_properties.append(("$.ct", str(message_to_send.content_type))) - - if message_to_send.content_encoding: - system_properties.append(("$.ce", str(message_to_send.content_encoding))) - - if message_to_send.iothub_interface_id: - system_properties.append(("$.ifid", str(message_to_send.iothub_interface_id))) - - if message_to_send.expiry_time_utc: - system_properties.append( - ( - "$.exp", - message_to_send.expiry_time_utc.isoformat() # returns string - if isinstance(message_to_send.expiry_time_utc, date) - else message_to_send.expiry_time_utc, - ) - ) - - system_properties_encoded = urllib.parse.urlencode( - system_properties, quote_via=urllib.parse.quote - ) - topic += system_properties_encoded - - if message_to_send.custom_properties and len(message_to_send.custom_properties) > 0: - if system_properties and len(system_properties) > 0: - topic += "&" - - # Convert the custom properties to a sorted list in order to ensure the - # resulting ordering in the topic string is consistent across versions of Python. - # Convert to the properties to strings for safety. 
- custom_prop_seq = [ - (str(i[0]), str(i[1])) for i in list(message_to_send.custom_properties.items()) - ] - custom_prop_seq.sort() - - # Validate that string conversion has not created duplicate keys - keys = [i[0] for i in custom_prop_seq] - if len(keys) != len(set(keys)): - raise ValueError("Duplicate keys in custom properties!") - - user_properties_encoded = urllib.parse.urlencode( - custom_prop_seq, quote_via=urllib.parse.quote - ) - topic += user_properties_encoded - - return topic - - -def _extract_properties(properties_str): - """Return a dictionary of properties from a string in the format - ${key1}={value1}&${key2}={value2}...&${keyn}={valuen} - """ - d = {} - kv_pairs = properties_str.split("&") - - for entry in kv_pairs: - pair = entry.split("=") - key = urllib.parse.unquote(pair[0]).lstrip("$") - value = urllib.parse.unquote(pair[1]) - d[key] = value - - return d diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_events_iothub.py b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_events_iothub.py deleted file mode 100644 index 9eb847e64..000000000 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_events_iothub.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from azure.iot.device.common.pipeline import PipelineEvent - - -class C2DMessageEvent(PipelineEvent): - """ - A PipelineEvent object which represents an incoming C2D event. This object is probably - created by some converter stage based on a protocol-specific event - """ - - def __init__(self, message): - """ - Initializer for C2DMessageEvent objects. - - :param Message message: The Message object for the message that was received. 
- """ - super().__init__() - self.message = message - - -class InputMessageEvent(PipelineEvent): - """ - A PipelineEvent object which represents an incoming input message event. This object is probably - created by some converter stage based on a protocol-specific event - """ - - def __init__(self, message): - """ - Initializer for InputMessageEvent objects. - - :param Message message: The Message object for the message that was received. This message - is expected to have had the .input_name attribute set - """ - super().__init__() - self.message = message - - -class MethodRequestEvent(PipelineEvent): - """ - A PipelineEvent object which represents an incoming MethodRequest event. - This object is probably created by some converter stage based on a protocol-specific event. - """ - - def __init__(self, method_request): - super().__init__() - self.method_request = method_request - - -class TwinDesiredPropertiesPatchEvent(PipelineEvent): - """ - A PipelineEvent object which represents an incoming twin desired properties patch. This - object is probably created by some converter stage based on a protocol-specific event. - """ - - def __init__(self, patch): - super().__init__() - self.patch = patch diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_ops_iothub.py b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_ops_iothub.py deleted file mode 100644 index 9556b4569..000000000 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_ops_iothub.py +++ /dev/null @@ -1,103 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -from azure.iot.device.common.pipeline import PipelineOperation - - -class SendD2CMessageOperation(PipelineOperation): - """ - A PipelineOperation object which contains arguments used to send a telemetry message to an IoTHub or EdgeHub server. - - This operation is in the group of IoTHub operations because it is very specific to the IoTHub client - """ - - def __init__(self, message, callback): - """ - Initializer for SendD2CMessageOperation objects. - - :param Message message: The message that we're sending to the service - :param Function callback: The function that gets called when this operation is complete or has failed. - The callback function must accept A PipelineOperation object which indicates the specific operation which - has completed or failed. - """ - super().__init__(callback=callback) - self.message = message - - -class SendOutputMessageOperation(PipelineOperation): - """ - A PipelineOperation object which contains arguments used to send an output message to an EdgeHub server. - - This operation is in the group of IoTHub operations because it is very specific to the IoTHub client - """ - - def __init__(self, message, callback): - """ - Initializer for SendOutputMessageOperation objects. - - :param Message message: The output message that we're sending to the service. The name of the output is - expected to be stored in the output_name attribute of this object - :param Function callback: The function that gets called when this operation is complete or has failed. - The callback function must accept A PipelineOperation object which indicates the specific operation which - has completed or failed. - """ - super().__init__(callback=callback) - self.message = message - - -class SendMethodResponseOperation(PipelineOperation): - """ - A PipelineOperation object which contains arguments used to send a method response to an IoTHub or EdgeHub server. 
- - This operation is in the group of IoTHub operations because it is very specific to the IoTHub client. - """ - - def __init__(self, method_response, callback): - """ - Initializer for SendMethodResponseOperation objects. - - :param method_response: The method response to be sent to IoTHub/EdgeHub - :type method_response: MethodResponse - :param callback: The function that gets called when this operation is complete or has failed. - The callback function must accept a PipelineOperation object which indicates the specific operation has which - has completed or failed. - :type callback: Function/callable - """ - super().__init__(callback=callback) - self.method_response = method_response - - -class GetTwinOperation(PipelineOperation): - """ - A PipelineOperation object which represents a request to get a device twin or a module twin from an Azure - IoT Hub or Azure Iot Edge Hub service. - - :ivar twin: Upon completion, this contains the twin which was retrieved from the service. - :type twin: Twin - """ - - def __init__(self, callback): - """ - Initializer for GetTwinOperation objects. - """ - super().__init__(callback=callback) - self.twin = None - - -class PatchTwinReportedPropertiesOperation(PipelineOperation): - """ - A PipelineOperation object which contains arguments used to send a reported properties patch to the Azure - IoT Hub or Azure IoT Edge Hub service. - """ - - def __init__(self, patch, callback): - """ - Initializer for PatchTwinReportedPropertiesOperation object - - :param patch: The reported properties patch to send to the service. 
- :type patch: dict, str, int, float, bool, or None (JSON compatible values) - """ - super().__init__(callback=callback) - self.patch = patch diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_ops_iothub_http.py b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_ops_iothub_http.py deleted file mode 100644 index 7f66c2294..000000000 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_ops_iothub_http.py +++ /dev/null @@ -1,79 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from azure.iot.device.common.pipeline import PipelineOperation - - -class MethodInvokeOperation(PipelineOperation): - """ - A PipelineOperation object which contains arguments used to send a method invoke to an IoTHub or EdgeHub server. - - This operation is in the group of EdgeHub operations because it is very specific to the EdgeHub client. - """ - - def __init__(self, target_device_id, target_module_id, method_params, callback): - """ - Initializer for MethodInvokeOperation objects. - - :param str target_device_id: The device id of the target device/module - :param str target_module_id: The module id of the target module - :param method_params: The parameters used to invoke the method, as defined by the IoT Hub specification. - :param callback: The function that gets called when this operation is complete or has failed. - The callback function must accept a PipelineOperation object which indicates the specific operation has which - has completed or failed. 
- :type callback: Function/callable - """ - super().__init__(callback=callback) - self.target_device_id = target_device_id - self.target_module_id = target_module_id - self.method_params = method_params - self.method_response = None - - -class GetStorageInfoOperation(PipelineOperation): - """ - A PipelineOperation object which contains arguments used to get the storage information from IoT Hub. - """ - - def __init__(self, blob_name, callback): - """ - Initializer for GetStorageInfo objects. - - :param str blob_name: The name of the blob that will be created in Azure Storage - :param callback: The function that gets called when this operation is complete or has failed. - The callback function must accept a PipelineOperation object which indicates the specific operation has which - has completed or failed. - :type callback: Function/callable - - :ivar dict storage_info: Upon completion, this contains the storage information which was retrieved from the service. - """ - super().__init__(callback=callback) - self.blob_name = blob_name - self.storage_info = None - - -class NotifyBlobUploadStatusOperation(PipelineOperation): - """ - A PipelineOperation object which contains arguments used to get the storage information from IoT Hub. - """ - - def __init__(self, correlation_id, is_success, status_code, status_description, callback): - """ - Initializer for GetStorageInfo objects. - - :param str correlation_id: Provided by IoT Hub on get_storage_info_for_blob request. - :param bool is_success: A boolean that indicates whether the file was uploaded successfully. - :param int request_status_code: A numeric status code that is the status for the upload of the file to storage. - :param str status_description: A description that corresponds to the status_code. - :param callback: The function that gets called when this operation is complete or has failed. 
- The callback function must accept a PipelineOperation object which indicates the specific operation has which - has completed or failed. - :type callback: Function/callable - """ - super().__init__(callback=callback) - self.correlation_id = correlation_id - self.is_success = is_success - self.request_status_code = status_code - self.status_description = status_description diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub.py b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub.py deleted file mode 100644 index da3af4839..000000000 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub.py +++ /dev/null @@ -1,210 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import json -import logging -from azure.iot.device.common.pipeline import ( - pipeline_events_base, - pipeline_ops_base, - PipelineStage, - pipeline_thread, -) -from azure.iot.device import exceptions -from . import pipeline_events_iothub, pipeline_ops_iothub -from . import constant - -logger = logging.getLogger(__name__) - - -class EnsureDesiredPropertiesStage(PipelineStage): - """ - Pipeline stage Responsible for making sure that desired properties are always kept up to date. - It does this by sending down a GetTwinOperation after a connection is reestablished, and, if - the desired properties have changed since the last time a patch was received, it will send up - an artificial patch event to send those updated properties to the app. 
- """ - - def __init__(self): - self.last_version_seen = None - self.pending_get_request = None - super().__init__() - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - if self.nucleus.pipeline_configuration.ensure_desired_properties: - if isinstance(op, pipeline_ops_base.EnableFeatureOperation): - # Ensure_desired_properties enables twin patches, when true, by setting last version - # to -1. The ConnectedEvent handler sees this and sends a GetTwinOperation to refresh - # desired properties. Setting ensure_desired_properties to false causes the GetTwinOp - # to not be sent. The rest of the functions in this stage stem from the GetTwinOperation, - # so disabling ensure_desired_properties effectively disables this stage. - - if op.feature_name == constant.TWIN_PATCHES: - logger.debug( - "{}: enabling twin patches. setting last_version_seen".format(self.name) - ) - self.last_version_seen = -1 - self.send_op_down(op) - - @pipeline_thread.runs_on_pipeline_thread - def _ensure_get_op(self): - """ - Function which makes sure we have a GetTwin operation in progress. If we've - already sent one down and we're waiting for it to return, we don't want to send - a new one down. This is because layers below us (especially CoordinateRequestAndResponseStage) - will do everything they can to ensure we get a response on the already-pending - GetTwinOperation. - """ - if not self.pending_get_request: - logger.info("{}: sending twin GET to ensure freshness".format(self.name)) - self.pending_get_request = pipeline_ops_iothub.GetTwinOperation( - callback=self._on_get_twin_complete - ) - self.send_op_down(self.pending_get_request) - else: - logger.debug( - "{}: Outstanding twin GET already exists. Not sending anything".format(self.name) - ) - - @pipeline_thread.runs_on_pipeline_thread - def _on_get_twin_complete(self, op, error): - """ - Function that gets called when a GetTwinOperation _that_we_initiated_ is complete. 
- This is where we compare $version values and decide if we want to create an artificial - TwinDesiredPropertiesPatchEvent or not. - """ - - self.pending_get_request = None - if error and self.nucleus.connected: - # If the GetTwinOperation failed and the client is connected, we blindly try again as - # long as we are connected. We run the risk of repeating this forever and might need - # to add logic to "give up" after some number of failures, but we don't have any real - # reason to add that yet. - logger.debug("{}: Twin GET failed with error {}. Resubmitting.".format(self, error)) - self._ensure_get_op() - elif error and not self.nucleus.connected: - # If the GetTwinOperation failed and the client is in any state but connected, - # (e.g. connecting, disconnecting, etc.) we consider the operation completed. - logger.debug( - "{}: Twin GET failed with error {}. Giving up, as pipeline is not connected." - ) - else: - # If the GetTwinOperation is successful, we compare the $version values and create - # an artificial patch if the versions do not match. - logger.debug("{} Twin GET response received. Checking versions".format(self)) - new_version = op.twin["desired"]["$version"] - logger.debug( - "{}: old version = {}, new version = {}".format( - self.name, self.last_version_seen, new_version - ) - ) - if self.last_version_seen != new_version: - # The twin we received has different (presumably newer) desired properties. - # Make an artificial patch and send it up - - logger.debug("{}: Version changed. Sending up new patch event".format(self.name)) - self.last_version_seen = new_version - self.send_event_up( - pipeline_events_iothub.TwinDesiredPropertiesPatchEvent(op.twin["desired"]) - ) - - @pipeline_thread.runs_on_pipeline_thread - def _handle_pipeline_event(self, event): - if self.nucleus.pipeline_configuration.ensure_desired_properties: - if isinstance(event, pipeline_events_iothub.TwinDesiredPropertiesPatchEvent): - # remember the $version when we get a patch. 
- version = event.patch["$version"] - logger.debug( - "{}: Desired patch received. Saving $version={}".format(self.name, version) - ) - self.last_version_seen = version - elif isinstance(event, pipeline_events_base.ConnectedEvent): - # If last_version_seen is truthy, that means we've seen desired property patches - # before (or we've enabled them at least). If this is the case, get the twin to - # see if the desired props have been updated. - if self.last_version_seen: - logger.info("{}: Reconnected. Getting twin".format(self.name)) - self._ensure_get_op() - self.send_event_up(event) - - -class TwinRequestResponseStage(PipelineStage): - """ - PipelineStage which handles twin operations. In particular, it converts twin GET and PATCH - operations into RequestAndResponseOperation operations. This is done at the IoTHub level because - there is nothing protocol-specific about this code. The protocol-specific implementation - for twin requests and responses is handled inside IoTHubMQTTTranslationStage, when it converts - the RequestOperation to a protocol-specific send operation and when it converts the - protocol-specific receive event into an ResponseEvent event. 
- """ - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - def map_twin_error(error, twin_op): - if error: - return error - elif twin_op.status_code >= 300: - # TODO map error codes to correct exceptions - logger.info("Error {} received from twin operation".format(twin_op.status_code)) - logger.info("response body: {}".format(twin_op.response_body)) - return exceptions.ServiceError( - "twin operation returned status {}".format(twin_op.status_code) - ) - - if isinstance(op, pipeline_ops_iothub.GetTwinOperation): - - # Alias to avoid overload within the callback below - # CT-TODO: remove the need for this with better callback semantics - op_waiting_for_response = op - - def on_twin_response(op, error): - logger.debug("{}({}): Got response for GetTwinOperation".format(self.name, op.name)) - error = map_twin_error(error=error, twin_op=op) - if not error: - op_waiting_for_response.twin = json.loads(op.response_body.decode("utf-8")) - op_waiting_for_response.complete(error=error) - - self.send_op_down( - pipeline_ops_base.RequestAndResponseOperation( - request_type=constant.TWIN, - method="GET", - resource_location="/", - request_body=" ", - callback=on_twin_response, - ) - ) - - elif isinstance(op, pipeline_ops_iothub.PatchTwinReportedPropertiesOperation): - - # Alias to avoid overload within the callback below - # CT-TODO: remove the need for this with better callback semantics - op_waiting_for_response = op - - def on_twin_response(op, error): - logger.debug( - "{}({}): Got response for PatchTwinReportedPropertiesOperation operation".format( - self.name, op.name - ) - ) - error = map_twin_error(error=error, twin_op=op) - op_waiting_for_response.complete(error=error) - - logger.debug( - "{}({}): Sending reported properties patch: {}".format(self.name, op.name, op.patch) - ) - - self.send_op_down( - pipeline_ops_base.RequestAndResponseOperation( - request_type=constant.TWIN, - method="PATCH", - resource_location="/properties/reported/", - 
request_body=json.dumps(op.patch), - callback=on_twin_response, - ) - ) - - else: - super()._run_op(op) diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_http.py b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_http.py deleted file mode 100644 index 01ecf4b2f..000000000 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_http.py +++ /dev/null @@ -1,192 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -import json -import urllib -from azure.iot.device.common.pipeline import ( - pipeline_ops_http, - PipelineStage, - pipeline_thread, -) -from . import pipeline_ops_iothub_http, http_path_iothub, http_map_error -from azure.iot.device import exceptions -from azure.iot.device import constant as pkg_constant -from azure.iot.device import user_agent - - -logger = logging.getLogger(__name__) - - -@pipeline_thread.runs_on_pipeline_thread -def map_http_error(error, http_op): - if error: - return error - elif http_op.status_code >= 300: - translated_error = http_map_error.translate_error(http_op.status_code, http_op.reason) - return exceptions.ServiceError( - "HTTP operation returned: {} {}".format(http_op.status_code, translated_error) - ) - - -class IoTHubHTTPTranslationStage(PipelineStage): - """ - PipelineStage which converts other Iot and EdgeHub operations into HTTP operations. This stage also - converts http pipeline events into Iot and EdgeHub pipeline events. 
- """ - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - if isinstance(op, pipeline_ops_iothub_http.MethodInvokeOperation): - logger.debug( - "{}({}): Translating Method Invoke Operation for HTTP.".format(self.name, op.name) - ) - query_params = "api-version={apiVersion}".format( - apiVersion=pkg_constant.IOTHUB_API_VERSION - ) - # if the target is a module. - - body = json.dumps(op.method_params) - path = http_path_iothub.get_method_invoke_path(op.target_device_id, op.target_module_id) - # NOTE: we do not add the sas Authorization header here. Instead we add it later on in - # the HTTPTransportStage - x_ms_edge_string = "{deviceId}/{moduleId}".format( - deviceId=self.nucleus.pipeline_configuration.device_id, - moduleId=self.nucleus.pipeline_configuration.module_id, - ) # these are the identifiers of the current module - user_agent_string = urllib.parse.quote_plus( - user_agent.get_iothub_user_agent() - + str(self.nucleus.pipeline_configuration.product_info) - ) - # Method Invoke must be addressed to the gateway hostname because it is an Edge op - headers = { - "Host": self.nucleus.pipeline_configuration.gateway_hostname, - "Content-Type": "application/json", - "Content-Length": str(len(str(body))), - "x-ms-edge-moduleId": x_ms_edge_string, - "User-Agent": user_agent_string, - } - op_waiting_for_response = op - - def on_request_response(op, error): - logger.debug( - "{}({}): Got response for MethodInvokeOperation".format(self.name, op.name) - ) - error = map_http_error(error=error, http_op=op) - if not error: - op_waiting_for_response.method_response = json.loads(op.response_body) - op_waiting_for_response.complete(error=error) - - self.send_op_down( - pipeline_ops_http.HTTPRequestAndResponseOperation( - method="POST", - path=path, - headers=headers, - body=body, - query_params=query_params, - callback=on_request_response, - ) - ) - - elif isinstance(op, pipeline_ops_iothub_http.GetStorageInfoOperation): - logger.debug( - "{}({}): Translating 
Get Storage Info Operation to HTTP.".format(self.name, op.name) - ) - query_params = "api-version={apiVersion}".format( - apiVersion=pkg_constant.IOTHUB_API_VERSION - ) - path = http_path_iothub.get_storage_info_for_blob_path( - self.nucleus.pipeline_configuration.device_id - ) - body = json.dumps({"blobName": op.blob_name}) - user_agent_string = urllib.parse.quote_plus( - user_agent.get_iothub_user_agent() - + str(self.nucleus.pipeline_configuration.product_info) - ) - headers = { - "Host": self.nucleus.pipeline_configuration.hostname, - "Accept": "application/json", - "Content-Type": "application/json", - "Content-Length": str(len(str(body))), - "User-Agent": user_agent_string, - } - - op_waiting_for_response = op - - def on_request_response(op, error): - logger.debug( - "{}({}): Got response for GetStorageInfoOperation".format(self.name, op.name) - ) - error = map_http_error(error=error, http_op=op) - if not error: - op_waiting_for_response.storage_info = json.loads(op.response_body) - op_waiting_for_response.complete(error=error) - - self.send_op_down( - pipeline_ops_http.HTTPRequestAndResponseOperation( - method="POST", - path=path, - headers=headers, - body=body, - query_params=query_params, - callback=on_request_response, - ) - ) - - elif isinstance(op, pipeline_ops_iothub_http.NotifyBlobUploadStatusOperation): - logger.debug( - "{}({}): Translating Get Storage Info Operation to HTTP.".format(self.name, op.name) - ) - query_params = "api-version={apiVersion}".format( - apiVersion=pkg_constant.IOTHUB_API_VERSION - ) - path = http_path_iothub.get_notify_blob_upload_status_path( - self.nucleus.pipeline_configuration.device_id - ) - body = json.dumps( - { - "correlationId": op.correlation_id, - "isSuccess": op.is_success, - "statusCode": op.request_status_code, - "statusDescription": op.status_description, - } - ) - user_agent_string = urllib.parse.quote_plus( - user_agent.get_iothub_user_agent() - + str(self.nucleus.pipeline_configuration.product_info) - ) - - 
# NOTE we do not add the sas Authorization header here. Instead we add it later on in - # the HTTPTransportStage - headers = { - "Host": self.nucleus.pipeline_configuration.hostname, - "Content-Type": "application/json; charset=utf-8", - "Content-Length": str(len(str(body))), - "User-Agent": user_agent_string, - } - op_waiting_for_response = op - - def on_request_response(op, error): - logger.debug( - "{}({}): Got response for GetStorageInfoOperation".format(self.name, op.name) - ) - error = map_http_error(error=error, http_op=op) - op_waiting_for_response.complete(error=error) - - self.send_op_down( - pipeline_ops_http.HTTPRequestAndResponseOperation( - method="POST", - path=path, - headers=headers, - body=body, - query_params=query_params, - callback=on_request_response, - ) - ) - - else: - # All other operations get passed down - self.send_op_down(op) diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_mqtt.py b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_mqtt.py deleted file mode 100644 index 128d2c04d..000000000 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_mqtt.py +++ /dev/null @@ -1,232 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -import json -import urllib -from azure.iot.device.common.pipeline import ( - pipeline_events_base, - pipeline_ops_base, - pipeline_ops_mqtt, - pipeline_events_mqtt, - PipelineStage, - pipeline_thread, -) -from azure.iot.device.iothub.models import Message, MethodRequest -from . import pipeline_ops_iothub, pipeline_events_iothub, mqtt_topic_iothub -from . import constant as pipeline_constant -from . 
import exceptions as pipeline_exceptions -from azure.iot.device import constant as pkg_constant -from azure.iot.device import user_agent - -logger = logging.getLogger(__name__) - - -class IoTHubMQTTTranslationStage(PipelineStage): - """ - PipelineStage which converts other Iot and IoTHub operations into MQTT operations. This stage also - converts mqtt pipeline events into Iot and IoTHub pipeline events. - """ - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - - if isinstance(op, pipeline_ops_base.InitializePipelineOperation): - - if self.nucleus.pipeline_configuration.module_id: - # Module Format - client_id = "{}/{}".format( - self.nucleus.pipeline_configuration.device_id, - self.nucleus.pipeline_configuration.module_id, - ) - else: - # Device Format - client_id = self.nucleus.pipeline_configuration.device_id - - query_param_seq = [] - - # Apply query parameters (i.e. key1=value1&key2=value2...&keyN=valueN format) - custom_product_info = str(self.nucleus.pipeline_configuration.product_info) - if custom_product_info.startswith( - pkg_constant.DIGITAL_TWIN_PREFIX - ): # Digital Twin Stuff - query_param_seq.append(("api-version", pkg_constant.DIGITAL_TWIN_API_VERSION)) - query_param_seq.append(("DeviceClientType", user_agent.get_iothub_user_agent())) - query_param_seq.append( - (pkg_constant.DIGITAL_TWIN_QUERY_HEADER, custom_product_info) - ) - else: - query_param_seq.append(("api-version", pkg_constant.IOTHUB_API_VERSION)) - query_param_seq.append( - ("DeviceClientType", user_agent.get_iothub_user_agent() + custom_product_info) - ) - - # NOTE: Client ID (including the device and/or module ids that are in it) - # is NOT url encoded as part of the username. Neither is the hostname. - # The sequence of key/value property pairs (query_param_seq) however, MUST have all - # keys and values URL encoded. 
- # See the repo wiki article for details: - # https://github.com/Azure/azure-iot-sdk-python/wiki/URL-Encoding-(MQTT) - username = "{hostname}/{client_id}/?{query_params}".format( - hostname=self.nucleus.pipeline_configuration.hostname, - client_id=client_id, - query_params=urllib.parse.urlencode(query_param_seq, quote_via=urllib.parse.quote), - ) - - # Dynamically attach the derived MQTT values to the InitializePipelineOperation - # to be used later down the pipeline - op.username = username - op.client_id = client_id - - self.send_op_down(op) - - elif isinstance(op, pipeline_ops_iothub.SendD2CMessageOperation) or isinstance( - op, pipeline_ops_iothub.SendOutputMessageOperation - ): - # Convert SendTelemetry and SendOutputMessageOperation operations into MQTT Publish operations - telemetry_topic = mqtt_topic_iothub.get_telemetry_topic_for_publish( - device_id=self.nucleus.pipeline_configuration.device_id, - module_id=self.nucleus.pipeline_configuration.module_id, - ) - topic = mqtt_topic_iothub.encode_message_properties_in_topic( - op.message, telemetry_topic - ) - worker_op = op.spawn_worker_op( - worker_op_type=pipeline_ops_mqtt.MQTTPublishOperation, - topic=topic, - payload=op.message.data, - ) - self.send_op_down(worker_op) - - elif isinstance(op, pipeline_ops_iothub.SendMethodResponseOperation): - # Sending a Method Response gets translated into an MQTT Publish operation - topic = mqtt_topic_iothub.get_method_topic_for_publish( - op.method_response.request_id, op.method_response.status - ) - payload = json.dumps(op.method_response.payload) - worker_op = op.spawn_worker_op( - worker_op_type=pipeline_ops_mqtt.MQTTPublishOperation, topic=topic, payload=payload - ) - self.send_op_down(worker_op) - - elif isinstance(op, pipeline_ops_base.EnableFeatureOperation): - # Enabling a feature gets translated into an MQTT subscribe operation - topic = self._get_feature_subscription_topic(op.feature_name) - worker_op = op.spawn_worker_op( - 
worker_op_type=pipeline_ops_mqtt.MQTTSubscribeOperation, topic=topic - ) - self.send_op_down(worker_op) - - elif isinstance(op, pipeline_ops_base.DisableFeatureOperation): - # Disabling a feature gets turned into an MQTT unsubscribe operation - topic = self._get_feature_subscription_topic(op.feature_name) - worker_op = op.spawn_worker_op( - worker_op_type=pipeline_ops_mqtt.MQTTUnsubscribeOperation, topic=topic - ) - self.send_op_down(worker_op) - - elif isinstance(op, pipeline_ops_base.RequestOperation): - if op.request_type == pipeline_constant.TWIN: - topic = mqtt_topic_iothub.get_twin_topic_for_publish( - method=op.method, - resource_location=op.resource_location, - request_id=op.request_id, - ) - worker_op = op.spawn_worker_op( - worker_op_type=pipeline_ops_mqtt.MQTTPublishOperation, - topic=topic, - payload=op.request_body, - ) - self.send_op_down(worker_op) - else: - raise pipeline_exceptions.OperationError( - "RequestOperation request_type {} not supported".format(op.request_type) - ) - - else: - # All other operations get passed down - super()._run_op(op) - - @pipeline_thread.runs_on_pipeline_thread - def _get_feature_subscription_topic(self, feature): - if feature == pipeline_constant.C2D_MSG: - return mqtt_topic_iothub.get_c2d_topic_for_subscribe( - self.nucleus.pipeline_configuration.device_id - ) - elif feature == pipeline_constant.INPUT_MSG: - return mqtt_topic_iothub.get_input_topic_for_subscribe( - self.nucleus.pipeline_configuration.device_id, - self.nucleus.pipeline_configuration.module_id, - ) - elif feature == pipeline_constant.METHODS: - return mqtt_topic_iothub.get_method_topic_for_subscribe() - elif feature == pipeline_constant.TWIN: - return mqtt_topic_iothub.get_twin_response_topic_for_subscribe() - elif feature == pipeline_constant.TWIN_PATCHES: - return mqtt_topic_iothub.get_twin_patch_topic_for_subscribe() - else: - logger.warning("Cannot retrieve MQTT topic for subscription to invalid feature") - raise pipeline_exceptions.OperationError( 
- "Trying to enable/disable invalid feature - {}".format(feature) - ) - - @pipeline_thread.runs_on_pipeline_thread - def _handle_pipeline_event(self, event): - """ - Pipeline Event handler function to convert incoming MQTT messages into the appropriate IoTHub - events, based on the topic of the message - """ - # TODO: should we always be decoding the payload? Seems strange to only sometimes do it. - # Is there value to the user getting the original bytestring from the wire? - if isinstance(event, pipeline_events_mqtt.IncomingMQTTMessageEvent): - topic = event.topic - device_id = self.nucleus.pipeline_configuration.device_id - module_id = self.nucleus.pipeline_configuration.module_id - - if mqtt_topic_iothub.is_c2d_topic(topic, device_id): - message = Message(event.payload) - mqtt_topic_iothub.extract_message_properties_from_topic(topic, message) - self.send_event_up(pipeline_events_iothub.C2DMessageEvent(message)) - - elif mqtt_topic_iothub.is_input_topic(topic, device_id, module_id): - message = Message(event.payload) - mqtt_topic_iothub.extract_message_properties_from_topic(topic, message) - message.input_name = mqtt_topic_iothub.get_input_name_from_topic(topic) - self.send_event_up(pipeline_events_iothub.InputMessageEvent(message)) - - elif mqtt_topic_iothub.is_method_topic(topic): - request_id = mqtt_topic_iothub.get_method_request_id_from_topic(topic) - method_name = mqtt_topic_iothub.get_method_name_from_topic(topic) - method_received = MethodRequest( - request_id=request_id, - name=method_name, - payload=json.loads(event.payload.decode("utf-8")), - ) - self.send_event_up(pipeline_events_iothub.MethodRequestEvent(method_received)) - - elif mqtt_topic_iothub.is_twin_response_topic(topic): - request_id = mqtt_topic_iothub.get_twin_request_id_from_topic(topic) - status_code = int(mqtt_topic_iothub.get_twin_status_code_from_topic(topic)) - self.send_event_up( - pipeline_events_base.ResponseEvent( - request_id=request_id, status_code=status_code, 
response_body=event.payload - ) - ) - - elif mqtt_topic_iothub.is_twin_desired_property_patch_topic(topic): - self.send_event_up( - pipeline_events_iothub.TwinDesiredPropertiesPatchEvent( - patch=json.loads(event.payload.decode("utf-8")) - ) - ) - - else: - logger.debug("Unknown topic: {} passing up to next handler".format(topic)) - self.send_event_up(event) - - else: - # all other messages get passed up - self.send_event_up(event) diff --git a/azure-iot-device/azure/iot/device/iothub/sync_clients.py b/azure-iot-device/azure/iot/device/iothub/sync_clients.py deleted file mode 100644 index a90cc04a0..000000000 --- a/azure-iot-device/azure/iot/device/iothub/sync_clients.py +++ /dev/null @@ -1,721 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module contains user-facing synchronous clients for the -Azure IoTHub Device SDK for Python. -""" - -import logging -import deprecation -from .abstract_clients import ( - AbstractIoTHubClient, - AbstractIoTHubDeviceClient, - AbstractIoTHubModuleClient, -) -from .models import Message -from .inbox_manager import InboxManager -from .sync_inbox import SyncClientInbox, InboxEmpty -from . 
import sync_handler_manager -from .pipeline import constant as pipeline_constant -from .pipeline import exceptions as pipeline_exceptions -from azure.iot.device import exceptions -from azure.iot.device.common.evented_callback import EventedCallback -from azure.iot.device import constant as device_constant - - -logger = logging.getLogger(__name__) - - -def handle_result(callback): - try: - return callback.wait_for_completion() - except pipeline_exceptions.ConnectionDroppedError as e: - raise exceptions.ConnectionDroppedError("Lost connection to IoTHub") from e - except pipeline_exceptions.ConnectionFailedError as e: - raise exceptions.ConnectionFailedError("Could not connect to IoTHub") from e - except pipeline_exceptions.NoConnectionError as e: - raise exceptions.NoConnectionError("Client is not connected to IoTHub") from e - except pipeline_exceptions.UnauthorizedError as e: - raise exceptions.CredentialError("Credentials invalid, could not connect") from e - except pipeline_exceptions.ProtocolClientError as e: - raise exceptions.ClientError("Error in the IoTHub client") from e - except pipeline_exceptions.TlsExchangeAuthError as e: - raise exceptions.ClientError("Error in the IoTHub client due to TLS exchanges.") from e - except pipeline_exceptions.ProtocolProxyError as e: - raise exceptions.ClientError( - "Error in the IoTHub client raised due to proxy connections." 
- ) from e - except pipeline_exceptions.PipelineNotRunning as e: - raise exceptions.ClientError("Client has already been shut down") from e - except pipeline_exceptions.OperationCancelled as e: - raise exceptions.OperationCancelled("Operation was cancelled before completion") from e - except pipeline_exceptions.OperationTimeout as e: - raise exceptions.OperationTimeout("Could not complete operation before timeout") from e - except Exception as e: - raise exceptions.ClientError("Unexpected failure") from e - - -class GenericIoTHubClient(AbstractIoTHubClient): - """A superclass representing a generic synchronous client. - This class needs to be extended for specific clients. - """ - - def __init__(self, **kwargs): - """Initializer for a generic synchronous client. - - This initializer should not be called directly. - Instead, use one of the 'create_from_' classmethods to instantiate - - :param mqtt_pipeline: The MQTTPipeline used for the client - :type mqtt_pipeline: :class:`azure.iot.device.iothub.pipeline.MQTTPipeline` - :param http_pipeline: The HTTPPipeline used for the client - :type http_pipeline: :class:`azure.iot.device.iothub.pipeline.HTTPPipeline` - """ - # Depending on the subclass calling this __init__, there could be different arguments, - # and the super() call could call a different class, due to the different MROs - # in the class hierarchies of different clients. Thus, args here must be passed along as - # **kwargs. 
- super().__init__(**kwargs) - self._inbox_manager = InboxManager(inbox_type=SyncClientInbox) - self._handler_manager = sync_handler_manager.SyncHandlerManager(self._inbox_manager) - - # Set pipeline handlers for client events - self._mqtt_pipeline.on_connected = self._on_connected - self._mqtt_pipeline.on_disconnected = self._on_disconnected - self._mqtt_pipeline.on_new_sastoken_required = self._on_new_sastoken_required - self._mqtt_pipeline.on_background_exception = self._on_background_exception - - # Set pipeline handlers for data receives - self._mqtt_pipeline.on_method_request_received = self._inbox_manager.route_method_request - self._mqtt_pipeline.on_twin_patch_received = self._inbox_manager.route_twin_patch - - def _enable_feature(self, feature_name): - """Enable an Azure IoT Hub feature. - - This is a synchronous call, meaning that this function will not return until the feature - has been enabled. - - :param feature_name: The name of the feature to enable. - See azure.iot.device.common.pipeline.constant for possible values - """ - logger.info("Enabling feature:" + feature_name + "...") - if not self._mqtt_pipeline.feature_enabled[feature_name]: - callback = EventedCallback() - self._mqtt_pipeline.enable_feature(feature_name, callback=callback) - callback.wait_for_completion() - - logger.info("Successfully enabled feature:" + feature_name) - else: - # This branch shouldn't be reached, but in case it is, log it - logger.info("Feature ({}) already disabled - skipping".format(feature_name)) - - def _disable_feature(self, feature_name): - """Disable an Azure IoT Hub feature - - This is a synchronous call, meaning that this function will not return until the feature - has been disabled. - - :param feature_name: The name of the feature to disable. 
- See azure.iot.device.common.pipeline.constant for possible values - """ - logger.info("Disabling feature: {}...".format(feature_name)) - if self._mqtt_pipeline.feature_enabled[feature_name]: - # Disable the feature if not already disabled - callback = EventedCallback() - self._mqtt_pipeline.disable_feature(feature_name, callback=callback) - callback.wait_for_completion() - - logger.info("Successfully disabled feature: {}".format(feature_name)) - else: - # This branch shouldn't be reached, but in case it is, log it - logger.info("Feature ({}) already disabled - skipping".format(feature_name)) - - def _generic_receive_handler_setter(self, handler_name, feature_name, new_handler): - """Set a receive handler on the handler manager and enable the corresponding feature. - - This is a synchronous call, meaning that this function will not return until the feature - has been enabled (if necessary). - - :param str handler_name: The name of the handler on the handler manager to set - :param str feature_name: The name of the pipeline feature that corresponds to the handler - :param new_handler: The function to be set as the handler - """ - self._check_receive_mode_is_handler() - # Set the handler on the handler manager - setattr(self._handler_manager, handler_name, new_handler) - - # Enable the feature if necessary - if new_handler is not None and not self._mqtt_pipeline.feature_enabled[feature_name]: - self._enable_feature(feature_name) - - # Disable the feature if necessary - elif new_handler is None and self._mqtt_pipeline.feature_enabled[feature_name]: - self._disable_feature(feature_name) - - def shutdown(self): - """Shut down the client for graceful exit. - - Once this method is called, any attempts at further client calls will result in a - ClientError being raised - - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. 
- """ - logger.info("Initiating client shutdown") - # Note that client disconnect does the following: - # - Disconnects the pipeline - # - Resolves all pending receiver handler calls - # - Stops receiver handler threads - self.disconnect() - - # Note that shutting down the following: - # - Disconnects the MQTT pipeline - # - Stops MQTT pipeline threads - logger.debug("Beginning pipeline shutdown operation") - callback = EventedCallback() - self._mqtt_pipeline.shutdown(callback=callback) - handle_result(callback) - logger.debug("Completed pipeline shutdown operation") - - # Stop the Client Event handlers now that everything else is completed - self._handler_manager.stop(receiver_handlers_only=False) - - # Yes, that means the pipeline is disconnected twice (well, actually three times if you - # consider that the client-level disconnect causes two pipeline-level disconnects for - # reasons explained in comments in the client's .disconnect() method). - # - # This last disconnect that occurs as a result of the pipeline shutdown is a bit different - # from the first though, in that it's more "final" and can't simply just be reconnected. - - # Note also that only the MQTT pipeline is shut down. The reason is twofold: - # 1. There are no known issues related to graceful exit if the HTTP pipeline is not - # explicitly shut down - # 2. The HTTP pipeline is planned for eventual removal from the client - # In light of these two facts, it seemed irrelevant to spend time implementing shutdown - # capability for HTTP pipeline. - logger.info("Client shutdown complete") - - def connect(self): - """Connects the client to an Azure IoT Hub or Azure IoT Edge Hub instance. - - The destination is chosen based on the credentials passed via the auth_provider parameter - that was provided when this object was initialized. - - This is a synchronous call, meaning that this function will not return until the connection - to the service has been completely established. 
- - :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid - and a connection cannot be established. - :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a - connection results in failure. - :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost - during execution. - :raises: :class:`azure.iot.device.exceptions.OperationTimeout` if the connection times out. - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. - """ - logger.info("Connecting to Hub...") - - callback = EventedCallback() - self._mqtt_pipeline.connect(callback=callback) - handle_result(callback) - - logger.info("Successfully connected to Hub") - - def disconnect(self): - """Disconnect the client from the Azure IoT Hub or Azure IoT Edge Hub instance. - - It is recommended that you make sure to call this function when you are completely done - with the your client instance. - - This is a synchronous call, meaning that this function will not return until the connection - to the service has been completely closed. - - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. 
- """ - logger.info("Disconnecting from Hub...") - - logger.debug("Executing initial disconnect") - callback = EventedCallback() - self._mqtt_pipeline.disconnect(callback=callback) - handle_result(callback) - logger.debug("Successfully executed initial disconnect") - - # Note that in the process of stopping the handlers and resolving pending calls - # a user-supplied handler may cause a reconnection to occur - logger.debug("Stopping handlers...") - self._handler_manager.stop(receiver_handlers_only=True) - logger.debug("Successfully stopped handlers") - - # Disconnect again to ensure disconnection has occurred due to the issue mentioned above - logger.debug("Executing secondary disconnect...") - callback = EventedCallback() - self._mqtt_pipeline.disconnect(callback=callback) - handle_result(callback) - logger.debug("Successfully executed secondary disconnect") - - # It's also possible that in the (very short) time between stopping the handlers and - # the second disconnect, additional items were received (e.g. C2D Message) - # Currently, this isn't really possible to accurately check due to a - # race condition / thread timing issue with inboxes where we can't guarantee how many - # items are truly in them. - # This has always been true of this client, even before handlers. - # - # However, even if the race condition is addressed, that will only allow us to log that - # messages were lost. To actually fix the problem, IoTHub needs to support MQTT5 so that - # we can unsubscribe from receiving data. - - logger.info("Successfully disconnected from Hub") - - def update_sastoken(self, sastoken): - """ - Update the client's SAS Token used for authentication, then reauthorizes the connection. - - This API can only be used if the client was initially created with a SAS Token. 
- - :param str sastoken: The new SAS Token string for the client to use - - :raises: ValueError if the sastoken parameter is invalid - :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid - and a connection cannot be re-established. - :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a re-establishing - the connection results in failure. - :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost - during execution. - :raises: :class:`azure.iot.device.exceptions.OperationTimeout` if the reauthorization - attempt times out. - :raises: :class:`azure.iot.device.exceptions.ClientError` if the client was not initially - created with a SAS token. - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. - """ - self._replace_user_supplied_sastoken(sastoken) - - # Reauthorize the connection - logger.info("Reauthorizing connection with Hub...") - callback = EventedCallback() - self._mqtt_pipeline.reauthorize_connection(callback=callback) - handle_result(callback) - # NOTE: Currently due to the MQTT3 implementation, the pipeline reauthorization will return - # after the disconnect. It does not wait for the reconnect to complete. This means that - # any errors that may occur as part of the connect will not return via this callback. - # They will instead go to the background exception handler. - - logger.info("Successfully reauthorized connection to Hub") - - def send_message(self, message): - """Sends a message to the default events endpoint on the Azure IoT Hub or Azure IoT Edge Hub instance. - - This is a synchronous event, meaning that this function will not return until the event - has been sent to the service and the service has acknowledged receipt of the event. - - If the connection to the service has not previously been opened by a call to connect, this - function will open the connection before sending the event. 
- - :param message: The actual message to send. Anything passed that is not an instance of the - Message class will be converted to Message object. - :type message: :class:`azure.iot.device.Message` or str - - :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid - and a connection cannot be established. - :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a - connection results in failure. - :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost - during execution. - :raises: :class:`azure.iot.device.exceptions.OperationTimeout` if connection attempt - times out - :raises: :class:`azure.iot.device.exceptions.NoConnectionError` if the client is not - connected (and there is no auto-connect enabled) - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. - :raises: ValueError if the message fails size validation. - """ - if not isinstance(message, Message): - message = Message(message) - - if message.get_size() > device_constant.TELEMETRY_MESSAGE_SIZE_LIMIT: - raise ValueError("Size of telemetry message can not exceed 256 KB.") - - logger.info("Sending message to Hub...") - - callback = EventedCallback() - self._mqtt_pipeline.send_message(message, callback=callback) - handle_result(callback) - - logger.info("Successfully sent message to Hub") - - @deprecation.deprecated( - deprecated_in="2.3.0", - current_version=device_constant.VERSION, - details="We recommend that you use the .on_method_request_received property to set a handler instead", - ) - def receive_method_request(self, method_name=None, block=True, timeout=None): - """Receive a method request via the Azure IoT Hub or Azure IoT Edge Hub. - - :param str method_name: Optionally provide the name of the method to receive requests for. 
- If this parameter is not given, all methods not already being specifically targeted by - a different request to receive_method will be received. - :param bool block: Indicates if the operation should block until a request is received. - :param int timeout: Optionally provide a number of seconds until blocking times out. - - :returns: MethodRequest object representing the received method request, or None if - no method request has been received by the end of the blocking period. - """ - self._check_receive_mode_is_api() - - if not self._mqtt_pipeline.feature_enabled[pipeline_constant.METHODS]: - self._enable_feature(pipeline_constant.METHODS) - - method_inbox = self._inbox_manager.get_method_request_inbox(method_name) - - logger.info("Waiting for method request...") - try: - method_request = method_inbox.get(block=block, timeout=timeout) - logger.info("Received method request") - except InboxEmpty: - method_request = None - logger.info("Did not receive method request") - return method_request - - def send_method_response(self, method_response): - """Send a response to a method request via the Azure IoT Hub or Azure IoT Edge Hub. - - This is a synchronous event, meaning that this function will not return until the event - has been sent to the service and the service has acknowledged receipt of the event. - - If the connection to the service has not previously been opened by a call to connect, this - function will open the connection before sending the event. - - :param method_response: The MethodResponse to send. - :type method_response: :class:`azure.iot.device.MethodResponse` - - :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid - and a connection cannot be established. - :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a - connection results in failure. - :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost - during execution. 
- :raises: :class:`azure.iot.device.exceptions.OperationTimeout` if connection attempt - times out - :raises: :class:`azure.iot.device.exceptions.NoConnectionError` if the client is not - connected (and there is no auto-connect enabled) - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. - """ - logger.info("Sending method response to Hub...") - - callback = EventedCallback() - self._mqtt_pipeline.send_method_response(method_response, callback=callback) - handle_result(callback) - - logger.info("Successfully sent method response to Hub") - - def get_twin(self): - """ - Gets the device or module twin from the Azure IoT Hub or Azure IoT Edge Hub service. - - This is a synchronous call, meaning that this function will not return until the twin - has been retrieved from the service. - - :returns: Complete Twin as a JSON dict - :rtype: dict - - :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid - and a connection cannot be established. - :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a - connection results in failure. - :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost - during execution. - :raises: :class:`azure.iot.device.exceptions.OperationTimeout` if connection attempt - times out - :raises: :class:`azure.iot.device.exceptions.NoConnectionError` if the client is not - connected (and there is no auto-connect enabled) - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. 
- """ - if not self._mqtt_pipeline.feature_enabled[pipeline_constant.TWIN]: - self._enable_feature(pipeline_constant.TWIN) - - callback = EventedCallback(return_arg_name="twin") - self._mqtt_pipeline.get_twin(callback=callback) - twin = handle_result(callback) - - logger.info("Successfully retrieved twin") - return twin - - def patch_twin_reported_properties(self, reported_properties_patch): - """ - Update reported properties with the Azure IoT Hub or Azure IoT Edge Hub service. - - This is a synchronous call, meaning that this function will not return until the patch - has been sent to the service and acknowledged. - - If the service returns an error on the patch operation, this function will raise the - appropriate error. - - :param reported_properties_patch: Twin Reported Properties patch as a JSON dict - :type reported_properties_patch: dict - - :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid - and a connection cannot be established. - :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a - connection results in failure. - :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost - during execution. - :raises: :class:`azure.iot.device.exceptions.OperationTimeout` if connection attempt - times out - :raises: :class:`azure.iot.device.exceptions.NoConnectionError` if the client is not - connected (and there is no auto-connect enabled) - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. 
- """ - if not self._mqtt_pipeline.feature_enabled[pipeline_constant.TWIN]: - self._enable_feature(pipeline_constant.TWIN) - - callback = EventedCallback() - self._mqtt_pipeline.patch_twin_reported_properties( - patch=reported_properties_patch, callback=callback - ) - handle_result(callback) - - logger.info("Successfully patched twin") - - @deprecation.deprecated( - deprecated_in="2.3.0", - current_version=device_constant.VERSION, - details="We recommend that you use the .on_twin_desired_properties_patch_received property to set a handler instead", - ) - def receive_twin_desired_properties_patch(self, block=True, timeout=None): - """ - Receive a desired property patch via the Azure IoT Hub or Azure IoT Edge Hub. - - This is a synchronous call, which means the following: - 1. If block=True, this function will block until one of the following happens: - * a desired property patch is received from the Azure IoT Hub or Azure IoT Edge Hub. - * the timeout period, if provided, elapses. If a timeout happens, this function will - raise a InboxEmpty exception - 2. If block=False, this function will return any desired property patches which may have - been received by the pipeline, but not yet returned to the application. If no - desired property patches have been received by the pipeline, this function will raise - an InboxEmpty exception - - :param bool block: Indicates if the operation should block until a request is received. - :param int timeout: Optionally provide a number of seconds until blocking times out. 
- - :returns: Twin Desired Properties patch as a JSON dict, or None if no patch has been - received by the end of the blocking period - :rtype: dict or None - """ - self._check_receive_mode_is_api() - - if not self._mqtt_pipeline.feature_enabled[pipeline_constant.TWIN_PATCHES]: - self._enable_feature(pipeline_constant.TWIN_PATCHES) - twin_patch_inbox = self._inbox_manager.get_twin_patch_inbox() - - logger.info("Waiting for twin patches...") - try: - patch = twin_patch_inbox.get(block=block, timeout=timeout) - logger.info("twin patch received") - except InboxEmpty: - logger.info("Did not receive twin patch") - return None - return patch - - -class IoTHubDeviceClient(GenericIoTHubClient, AbstractIoTHubDeviceClient): - """A synchronous device client that connects to an Azure IoT Hub instance.""" - - def __init__(self, mqtt_pipeline, http_pipeline): - """Initializer for a IoTHubDeviceClient. - - This initializer should not be called directly. - Instead, use one of the 'create_from_' classmethods to instantiate - - :param mqtt_pipeline: The pipeline used to connect to the IoTHub endpoint. - :type mqtt_pipeline: :class:`azure.iot.device.iothub.pipeline.MQTTPipeline` - """ - super().__init__(mqtt_pipeline=mqtt_pipeline, http_pipeline=http_pipeline) - self._mqtt_pipeline.on_c2d_message_received = self._inbox_manager.route_c2d_message - - @deprecation.deprecated( - deprecated_in="2.3.0", - current_version=device_constant.VERSION, - details="We recommend that you use the .on_message_received property to set a handler instead", - ) - def receive_message(self, block=True, timeout=None): - """Receive a message that has been sent from the Azure IoT Hub. - - :param bool block: Indicates if the operation should block until a message is received. - :param int timeout: Optionally provide a number of seconds until blocking times out. - - :returns: Message that was sent from the Azure IoT Hub, or None if - no method request has been received by the end of the blocking period. 
- :rtype: :class:`azure.iot.device.Message` or None - """ - self._check_receive_mode_is_api() - - if not self._mqtt_pipeline.feature_enabled[pipeline_constant.C2D_MSG]: - self._enable_feature(pipeline_constant.C2D_MSG) - c2d_inbox = self._inbox_manager.get_c2d_message_inbox() - - logger.info("Waiting for message from Hub...") - try: - message = c2d_inbox.get(block=block, timeout=timeout) - logger.info("Message received") - except InboxEmpty: - message = None - logger.info("No message received.") - return message - - def get_storage_info_for_blob(self, blob_name): - """Sends a POST request over HTTP to an IoTHub endpoint that will return information for uploading via the Azure Storage Account linked to the IoTHub your device is connected to. - - :param str blob_name: The name in string format of the blob that will be uploaded using the storage API. This name will be used to generate the proper credentials for Storage, and needs to match what will be used with the Azure Storage SDK to perform the blob upload. - - :returns: A JSON-like (dictionary) object from IoT Hub that will contain relevant information including: correlationId, hostName, containerName, blobName, sasToken. - """ - callback = EventedCallback(return_arg_name="storage_info") - self._http_pipeline.get_storage_info_for_blob(blob_name, callback=callback) - storage_info = handle_result(callback) - logger.info("Successfully retrieved storage_info") - return storage_info - - def notify_blob_upload_status( - self, correlation_id, is_success, status_code, status_description - ): - """When the upload is complete, the device sends a POST request to the IoT Hub endpoint with information on the status of an upload to blob attempt. This is used by IoT Hub to notify listening clients. - - :param str correlation_id: Provided by IoT Hub on get_storage_info_for_blob request. - :param bool is_success: A boolean that indicates whether the file was uploaded successfully. 
- :param int status_code: A numeric status code that is the status for the upload of the file to storage. - :param str status_description: A description that corresponds to the status_code. - """ - callback = EventedCallback() - self._http_pipeline.notify_blob_upload_status( - correlation_id=correlation_id, - is_success=is_success, - status_code=status_code, - status_description=status_description, - callback=callback, - ) - handle_result(callback) - logger.info("Successfully notified blob upload status") - - -class IoTHubModuleClient(GenericIoTHubClient, AbstractIoTHubModuleClient): - """A synchronous module client that connects to an Azure IoT Hub or Azure IoT Edge instance.""" - - def __init__(self, mqtt_pipeline, http_pipeline): - """Initializer for a IoTHubModuleClient. - - This initializer should not be called directly. - Instead, use one of the 'create_from_' classmethods to instantiate - - :param mqtt_pipeline: The pipeline used to connect to the IoTHub endpoint. - :type mqtt_pipeline: :class:`azure.iot.device.iothub.pipeline.MQTTPipeline` - :param http_pipeline: The pipeline used to connect to the IoTHub endpoint via HTTP. - :type http_pipeline: :class:`azure.iot.device.iothub.pipeline.HTTPPipeline` - """ - super().__init__(mqtt_pipeline=mqtt_pipeline, http_pipeline=http_pipeline) - self._mqtt_pipeline.on_input_message_received = self._inbox_manager.route_input_message - - def send_message_to_output(self, message, output_name): - """Sends an event/message to the given module output. - - These are outgoing events and are meant to be "output events". - - This is a synchronous event, meaning that this function will not return until the event - has been sent to the service and the service has acknowledged receipt of the event. - - If the connection to the service has not previously been opened by a call to connect, this - function will open the connection before sending the event. - - :param message: Message to send to the given output. 
Anything passed that is not an instance of the - Message class will be converted to Message object. - :type message: :class:`azure.iot.device.Message` or str - :param str output_name: Name of the output to send the event to. - - :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid - and a connection cannot be established. - :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a - connection results in failure. - :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost - during execution. - :raises: :class:`azure.iot.device.exceptions.OperationTimeout` if connection attempt - times out - :raises: :class:`azure.iot.device.exceptions.NoConnectionError` if the client is not - connected (and there is no auto-connect enabled) - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. - :raises: ValueError if the message fails size validation. - """ - if not isinstance(message, Message): - message = Message(message) - - if message.get_size() > device_constant.TELEMETRY_MESSAGE_SIZE_LIMIT: - raise ValueError("Size of message can not exceed 256 KB.") - - message.output_name = output_name - - logger.info("Sending message to output:" + output_name + "...") - - callback = EventedCallback() - self._mqtt_pipeline.send_output_message(message, callback=callback) - handle_result(callback) - - logger.info("Successfully sent message to output: " + output_name) - - @deprecation.deprecated( - deprecated_in="2.3.0", - current_version=device_constant.VERSION, - details="We recommend that you use the .on_message_received property to set a handler instead", - ) - def receive_message_on_input(self, input_name, block=True, timeout=None): - """Receive an input message that has been sent from another Module to a specific input. - - :param str input_name: The input name to receive a message on. 
- :param bool block: Indicates if the operation should block until a message is received. - :param int timeout: Optionally provide a number of seconds until blocking times out. - - :returns: Message that was sent to the specified input, or None if - no method request has been received by the end of the blocking period. - """ - self._check_receive_mode_is_api() - - if not self._mqtt_pipeline.feature_enabled[pipeline_constant.INPUT_MSG]: - self._enable_feature(pipeline_constant.INPUT_MSG) - input_inbox = self._inbox_manager.get_input_message_inbox(input_name) - - logger.info("Waiting for input message on: " + input_name + "...") - try: - message = input_inbox.get(block=block, timeout=timeout) - logger.info("Input message received on: " + input_name) - except InboxEmpty: - message = None - logger.info("No input message received on: " + input_name) - return message - - def invoke_method(self, method_params, device_id, module_id=None): - """Invoke a method from your client onto a device or module client, and receive the response to the method call. - - :param dict method_params: Should contain a methodName (str), payload (str), - connectTimeoutInSeconds (int), responseTimeoutInSeconds (int). - :param str device_id: Device ID of the target device where the method will be invoked. - :param str module_id: Module ID of the target module where the method will be invoked. 
(Optional) - - :returns: method_result should contain a status, and a payload - :rtype: dict - """ - logger.info( - "Invoking {} method on {}{}".format(method_params["methodName"], device_id, module_id) - ) - callback = EventedCallback(return_arg_name="invoke_method_response") - self._http_pipeline.invoke_method( - device_id, method_params, callback=callback, module_id=module_id - ) - invoke_method_response = handle_result(callback) - logger.info("Successfully invoked method") - return invoke_method_response diff --git a/azure-iot-device/azure/iot/device/iothub/sync_handler_manager.py b/azure-iot-device/azure/iot/device/iothub/sync_handler_manager.py deleted file mode 100644 index f1af68664..000000000 --- a/azure-iot-device/azure/iot/device/iothub/sync_handler_manager.py +++ /dev/null @@ -1,456 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -"""This module contains the manager for handler methods used by the callback client""" -import logging -import threading -import abc -from azure.iot.device.common import handle_exceptions -from azure.iot.device.iothub.client_event import ( - CONNECTION_STATE_CHANGE, - NEW_SASTOKEN_REQUIRED, - BACKGROUND_EXCEPTION, -) -import concurrent.futures - -logger = logging.getLogger(__name__) - -# Receiver Handlers -MESSAGE = "_on_message_received" -METHOD = "_on_method_request_received" -TWIN_DP_PATCH = "_on_twin_desired_properties_patch_received" - -# Client Event Handler Runner -CLIENT_EVENT = "client_event" -# Client Event Names -client_events = [CONNECTION_STATE_CHANGE, NEW_SASTOKEN_REQUIRED, BACKGROUND_EXCEPTION] - - -class HandlerManagerException(Exception): - """An exception raised by a HandlerManager""" - - pass - - -class HandlerRunnerKillerSentinel(object): - """An object that functions according to the sentinel design pattern. - Insert into an Inbox in order to indicate that the Handler Runner associated with that - Inbox should be stopped. - """ - - pass - - -class AbstractHandlerManager(abc.ABC): - """Partial class that defines handler manager functionality shared between sync/async""" - - def __init__(self, inbox_manager): - self._inbox_manager = inbox_manager - - self._receiver_handler_runners = { - MESSAGE: None, - METHOD: None, - TWIN_DP_PATCH: None, - } - self._client_event_runner = None - - # Receiver handlers (Each will have it's own runner) - self._on_message_received = None - self._on_method_request_received = None - self._on_twin_desired_properties_patch_received = None - - # Client Event handlers (Share a single Client Event runner) - self._on_connection_state_change = None - self._on_new_sastoken_required = None - self._on_background_exception = None - - # As mentioned above, Receiver handlers each get their own runner. 
This is because it is - # reasonably possible that many receives (and many different types of receives) can be - # happening in a very short time. This is because we want to be able to support processing - # multiple receives simultaneously, but also protect different receive invocations from - # being slowed by a poorly written handler of a different type. Thus each Receiver handler - # gets its own unique runner that cannot end up blocked by other handlers. - # - # Client Events on the other hand are generated by the client itself rather than - # unsolicited data received over the wire. As these are relatively uncommon, we don't - # expect multiple Client Events of the same type to really start queueing up, so there - # isn't much justification for each type of event getting it's own dedicated runner. - # Furthermore, there is much less concern over one inefficient or slow handler blocking - # execution of others due to this infrequency. - # - # However, there are other differences between these handler classes as well. Receiver - # handlers will always be invoked with a single argument - the received data structure. - # Client Event handlers on the other hand are more flexible - they may be invoked with - # different numbers of arguments depending on the handler, from none to multiple. This - # is to keep design space open for more complex Client Events in the future. - # - # Finally, it is possible to stop ONLY the receiver handlers via the .stop() method - # if desired. This is useful because while receiver handlers should be stopped when the - # client disconnects, Client Events can still occur while the client is disconnected - # as they are propagated from the client itself rather than received data. 
- - def _get_inbox_for_receive_handler(self, handler_name): - """Retrieve the inbox relevant to the handler""" - if handler_name == METHOD: - return self._inbox_manager.get_method_request_inbox() - elif handler_name == TWIN_DP_PATCH: - return self._inbox_manager.get_twin_patch_inbox() - elif handler_name == MESSAGE: - return self._inbox_manager.get_unified_message_inbox() - else: - return None - - def _get_handler_for_client_event(self, event_name): - """Retrieve the handler relevant to the event""" - if event_name == NEW_SASTOKEN_REQUIRED: - return self._on_new_sastoken_required - elif event_name == CONNECTION_STATE_CHANGE: - return self._on_connection_state_change - elif event_name == BACKGROUND_EXCEPTION: - return self._on_background_exception - else: - return None - - @abc.abstractmethod - def _receiver_handler_runner(self, inbox, handler_name): - """Run infinite loop that waits for an inbox to receive an object from it, then calls - the handler with that object - """ - pass - - @abc.abstractmethod - def _client_event_handler_runner(self, handler_name): - """Run infinite loop that waits for the client event inbox to receive an event from it, - then calls the handler that corresponds to that event - """ - pass - - @abc.abstractmethod - def _start_handler_runner(self, handler_name): - """Create, and store a handler runner""" - pass - - @abc.abstractmethod - def _stop_receiver_handler_runner(self, handler_name): - """Cancel and remove a handler runner""" - pass - - @abc.abstractmethod - def _stop_client_event_handler_runner(self): - """Cancel the client event handler runner""" - pass - - def _generic_receiver_handler_setter(self, handler_name, new_handler): - """Set a handler""" - curr_handler = getattr(self, handler_name) - if new_handler is not None and curr_handler is None: - # Create runner, set handler - logger.debug("Creating new handler runner for handler: {}".format(handler_name)) - setattr(self, handler_name, new_handler) - 
self._start_handler_runner(handler_name) - elif new_handler is None and curr_handler is not None: - # Cancel runner, remove handler - logger.debug("Removing handler runner for handler: {}".format(handler_name)) - self._stop_receiver_handler_runner(handler_name) - setattr(self, handler_name, new_handler) - else: - # Update handler, no need to change runner - logger.debug("Updating set handler: {}".format(handler_name)) - setattr(self, handler_name, new_handler) - - @staticmethod - def _generate_callback_for_handler(handler_name): - """Define a callback that can handle errors during handler execution""" - - def handler_callback(future): - try: - e = future.exception(timeout=0) - except Exception as raised_e: - # This shouldn't happen because cancellation or timeout shouldn't occur... - # But just in case... - new_err = HandlerManagerException( - "HANDLER ({}): Unable to retrieve exception data from incomplete invocation".format( - handler_name - ) - ) - new_err.__cause__ = raised_e - handle_exceptions.handle_background_exception(new_err) - else: - if e: - new_err = HandlerManagerException( - "HANDLER ({}): Error during invocation".format(handler_name), - ) - new_err.__cause__ = e - handle_exceptions.handle_background_exception(new_err) - else: - logger.debug( - "HANDLER ({}): Successfully completed invocation".format(handler_name) - ) - - return handler_callback - - def stop(self, receiver_handlers_only=False): - """Stop the process of invoking handlers in response to events. - All pending items will be handled prior to stoppage. 
- """ - # Stop receiver handlers - for handler_name in self._receiver_handler_runners: - if self._receiver_handler_runners[handler_name] is not None: - self._stop_receiver_handler_runner(handler_name) - - # Stop the client event handler (if instructed) - if not receiver_handlers_only and self._client_event_runner is not None: - self._stop_client_event_handler_runner() - - def ensure_running(self): - """Ensure the process of invoking handlers in response to events is running""" - # Ensure any receiver handler set on the manager has a corresponding handler runner running - for handler_name in self._receiver_handler_runners: - if ( - self._receiver_handler_runners[handler_name] is None - and getattr(self, handler_name) is not None - ): - self._start_handler_runner(handler_name) - - # Ensure client event handler runner is running if at least one client event handler is set - # on the manager - if self._client_event_runner is None: - for event in client_events: - handler = self._get_handler_for_client_event(event) - if handler is not None: - self._start_handler_runner(CLIENT_EVENT) - break - - # ~~~Receiver Handlers~~~ - # Setting a receiver handler will start a dedicated runner for that handler - # Removing a receiver handler will stop the dedicated runner for that handler - @property - def on_message_received(self): - return self._on_message_received - - @on_message_received.setter - def on_message_received(self, value): - self._generic_receiver_handler_setter(MESSAGE, value) - - @property - def on_method_request_received(self): - return self._on_method_request_received - - @on_method_request_received.setter - def on_method_request_received(self, value): - self._generic_receiver_handler_setter(METHOD, value) - - @property - def on_twin_desired_properties_patch_received(self): - return self._on_twin_desired_properties_patch_received - - @on_twin_desired_properties_patch_received.setter - def on_twin_desired_properties_patch_received(self, value): - 
self._generic_receiver_handler_setter(TWIN_DP_PATCH, value) - - # ~~~Client Event Handlers~~~ - # Setting any client event handler will start the shared client event handler runner - # Removing handlers will NOT stop the client event handler runner - you must use .stop() - # Stopping when all client event handlers are removed could be added if necessary. - @property - def on_connection_state_change(self): - return self._on_connection_state_change - - @on_connection_state_change.setter - def on_connection_state_change(self, value): - self._on_connection_state_change = value - if self._client_event_runner is None: - self._start_handler_runner(CLIENT_EVENT) - - @property - def on_new_sastoken_required(self): - return self._on_new_sastoken_required - - @on_new_sastoken_required.setter - def on_new_sastoken_required(self, value): - self._on_new_sastoken_required = value - if self._client_event_runner is None: - self._start_handler_runner(CLIENT_EVENT) - - @property - def on_background_exception(self): - return self._on_background_exception - - @on_background_exception.setter - def on_background_exception(self, value): - self._on_background_exception = value - if self._client_event_runner is None: - self._start_handler_runner(CLIENT_EVENT) - - # ~~~Other Properties~~~ - @property - def handling_client_events(self): - """Indicates if the HandlerManager is currently capable of resolving ClientEvents""" - # This client event runner is only running if at least one handler for client events has - # been set. If none have been set, it is dangerous to add items to the client event inbox - # as none will ever be retrieved due to no runner process occurring, thus the need for this - # check. - # - # The ideal solution would be to always keep the client event runner running, but this - # could break older customer code due to older APIs on the customer-facing clients. 
It is - # unfortunate that something related to an API has seeped into this internal and ideally - # isolated module, but the needs of the client design have influenced the design of this - # manager (by only starting the runner when a handler is set), so the mitigation must also - # be located in this module. - if self._client_event_runner is None: - return False - else: - return True - - -class SyncHandlerManager(AbstractHandlerManager): - """Handler manager for use with synchronous clients""" - - def _receiver_handler_runner(self, inbox, handler_name): - """Run infinite loop that waits for an inbox to receive an object from it, then calls - the handler with that object - """ - logger.debug("HANDLER RUNNER ({}): Starting runner".format(handler_name)) - _handler_callback = self._generate_callback_for_handler(handler_name) - - # Run the handler in a threadpool, so that it cannot block other handlers (from a different task), - # or the main client thread. The number of worker threads forms an upper bound on how many instances - # of the same handler can be running simultaneously. - tpe = concurrent.futures.ThreadPoolExecutor(max_workers=4) - while True: - handler_arg = inbox.get() - if isinstance(handler_arg, HandlerRunnerKillerSentinel): - # Exit the runner when a HandlerRunnerKillerSentinel is found - logger.debug( - "HANDLER RUNNER ({}): HandlerRunnerKillerSentinel found in inbox. Exiting.".format( - handler_name - ) - ) - tpe.shutdown() - break - # NOTE: we MUST use getattr here using the handler name, as opposed to directly passing - # the handler in order for the handler to be able to be updated without cancelling - # the running task created for this coroutine - handler = getattr(self, handler_name) - logger.debug("HANDLER RUNNER ({}): Invoking handler".format(handler_name)) - fut = tpe.submit(handler, handler_arg) - fut.add_done_callback(_handler_callback) - # Free up this object so the garbage collector can free it if necessary. 
If we don't - # do this, we end up keeping this object alive until the next event arrives, which - # might be a long time. Tests would flag this as a memory leak if that happened. - del handler_arg - - def _client_event_handler_runner(self): - """Run infinite loop that waits for the client event inbox to receive an event from it, - then calls the handler that corresponds to that event - """ - logger.debug("HANDLER RUNNER (CLIENT EVENT): Starting runner") - _handler_callback = self._generate_callback_for_handler("CLIENT_EVENT") - - tpe = concurrent.futures.ThreadPoolExecutor(max_workers=4) - event_inbox = self._inbox_manager.get_client_event_inbox() - while True: - event = event_inbox.get() - if isinstance(event, HandlerRunnerKillerSentinel): - # Exit the runner when a HandlerRunnerKillerSentinel is found - logger.debug( - "HANDLER RUNNER (CLIENT EVENT): HandlerRunnerKillerSentinel found in event queue. Exiting." - ) - tpe.shutdown() - break - handler = self._get_handler_for_client_event(event.name) - if handler is not None: - logger.debug( - "HANDLER RUNNER (CLIENT EVENT): {} event received. Invoking {} handler".format( - event, handler - ) - ) - fut = tpe.submit(handler, *event.args_for_user) - fut.add_done_callback(_handler_callback) - # Free up this object so the garbage collector can free it if necessary. If we don't - # do this, we end up keeping this object alive until the next event arrives, which - # might be a long time. Tests would flag this as a memory leak if that happened. - del event - else: - logger.debug( - "No handler for event {} set. Skipping handler invocation".format(event) - ) - - def _start_handler_runner(self, handler_name): - """Start and store a handler runner thread""" - # Client Event handler flow - if handler_name == CLIENT_EVENT: - if self._client_event_runner is not None: - # This branch of code should NOT be reachable due to checks prior to the invocation - # of this method. The branch exists for safety. 
- raise HandlerManagerException( - "Cannot create thread for handler runner: {}. Runner thread already exists".format( - handler_name - ) - ) - # Client events share a handler - thread = threading.Thread(target=self._client_event_handler_runner) - # Store the thread - self._client_event_runner = thread - - # Receiver handler flow - else: - if self._receiver_handler_runners[handler_name] is not None: - # This branch of code should NOT be reachable due to checks prior to the invocation - # of this method. The branch exists for safety. - raise HandlerManagerException( - "Cannot create thread for handler runner: {}. Runner thread already exists".format( - handler_name - ) - ) - inbox = self._get_inbox_for_receive_handler(handler_name) - # Each receiver handler gets its own runner - thread = threading.Thread( - target=self._receiver_handler_runner, args=[inbox, handler_name] - ) - # Store the thread - self._receiver_handler_runners[handler_name] = thread - - # NOTE: It would be nice to have some kind of mechanism for making sure this thread - # doesn't crash or raise errors, but it would require significant extra infrastructure - # and an exception in here isn't supposed to happen anyway. Perhaps it could be added - # later if truly necessary - thread.daemon = True # Don't block program exit - thread.start() - - def _stop_receiver_handler_runner(self, handler_name): - """Stop and remove a handler runner thread. - All pending items in the corresponding inbox will be handled by the handler before stoppage. 
- """ - logger.debug( - "Adding HandlerRunnerKillerSentinel to inbox corresponding to {} handler runner".format( - handler_name - ) - ) - inbox = self._get_inbox_for_receive_handler(handler_name) - inbox.put(HandlerRunnerKillerSentinel()) - - # Wait for Handler Runner to end due to the sentinel - logger.debug("Waiting for {} handler runner to exit...".format(handler_name)) - thread = self._receiver_handler_runners[handler_name] - thread.join() - self._receiver_handler_runners[handler_name] = None - logger.debug("Handler runner for {} has been stopped".format(handler_name)) - - def _stop_client_event_handler_runner(self): - """Stop and remove a handler runner thread. - All pending items in the client event queue will be handled by handlers (if they exist) - before stoppage. - """ - logger.debug("Adding HandlerRunnerKillerSentinel to client event queue") - event_inbox = self._inbox_manager.get_client_event_inbox() - event_inbox.put(HandlerRunnerKillerSentinel()) - - # Wait for Handler Runner to end due to the stop command - logger.debug("Waiting for client event handler runner to exit...") - thread = self._client_event_runner - thread.join() - self._client_event_runner = None - logger.debug("Handler runner for client events has been stopped") diff --git a/azure-iot-device/azure/iot/device/iothub/sync_inbox.py b/azure-iot-device/azure/iot/device/iothub/sync_inbox.py deleted file mode 100644 index c60d66ba6..000000000 --- a/azure-iot-device/azure/iot/device/iothub/sync_inbox.py +++ /dev/null @@ -1,121 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -"""This module contains an Inbox class for use with a synchronous client.""" - -import queue -import abc - - -class InboxEmpty(Exception): - pass - - -class AbstractInbox(abc.ABC): - """Abstract Base Class for Inbox. - - Holds generic incoming data for a client. - - All methods, when implemented, should be threadsafe. - """ - - @abc.abstractmethod - def put(self, item): - """Put an item into the Inbox. - - Implementation MUST be a synchronous function. - Only to be used by the InboxManager. - - :param item: The item to put in the Inbox. - """ - pass - - @abc.abstractmethod - def get(self): - """Remove and return an item from the inbox. - - Implementation should have the capability to block until an item is available. - Implementation can be a synchronous function or an asynchronous coroutine. - - :returns: An item from the Inbox. - """ - pass - - @abc.abstractmethod - def empty(self): - """Returns True if the inbox is empty, False otherwise - - :returns: Boolean indicating if the inbox is empty - """ - pass - - @abc.abstractmethod - def clear(self): - """Remove all items from the inbox.""" - pass - - -class SyncClientInbox(AbstractInbox): - """Holds generic incoming data for a synchronous client. - - All methods implemented in this class are threadsafe. - """ - - def __init__(self): - """Initializer for SyncClientInbox""" - self._queue = queue.Queue() - - def __contains__(self, item): - """Return True if item is in Inbox, False otherwise""" - with self._queue.mutex: - return item in self._queue.queue - - def put(self, item): - """Put an item into the inbox. - - Only to be used by the InboxManager. - - :param item: The item to put in the inbox. - """ - self._queue.put(item) - - def get(self, block=True, timeout=None): - """Remove and return an item from the inbox. - - :param bool block: Indicates if the operation should block until an item is available. - Default True. 
- :param int timeout: Optionally provide a number of seconds until blocking times out. - - :raises: InboxEmpty if timeout occurs because the inbox is empty - :raises: InboxEmpty if inbox is empty in non-blocking mode - - :returns: An item from the Inbox - """ - try: - return self._queue.get(block=block, timeout=timeout) - except queue.Empty: - raise InboxEmpty("Inbox is empty") - - def empty(self): - """Returns True if the inbox is empty, False otherwise. - - Note that there is a race condition here, and this may not be accurate as the queue size - may change while this operation is occurring. - - :returns: Boolean indicating if the inbox is empty - """ - return self._queue.empty() - - def join(self): - """Block until all items in the inbox have been gotten and processed. - - Only really used for test code. - """ - return self._queue.join() - - def clear(self): - """Remove all items from the inbox.""" - with self._queue.mutex: - self._queue.queue.clear() diff --git a/azure-iot-device/azure/iot/device/iothub_http_client.py b/azure-iot-device/azure/iot/device/iothub_http_client.py new file mode 100644 index 000000000..3c48d608e --- /dev/null +++ b/azure-iot-device/azure/iot/device/iothub_http_client.py @@ -0,0 +1,277 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import aiohttp +import asyncio +import logging +import urllib.parse +from typing import Optional, cast +from .custom_typing import DirectMethodParameters, DirectMethodResult, StorageInfo +from . import exceptions as exc +from . import config, constant, user_agent +from . 
import http_path_iothub as http_path
+
+logger = logging.getLogger(__name__)
+
+# Header Definitions
+HEADER_AUTHORIZATION = "Authorization"
+HEADER_EDGE_MODULE_ID = "x-ms-edge-moduleId"
+HEADER_USER_AGENT = "User-Agent"
+
+# Query parameter definitions
+PARAM_API_VERISON = "api-version"
+
+# Other definitions
+HTTP_TIMEOUT = 10
+
+# NOTE: Outstanding items in this module:
+# TODO: document aiohttp exceptions that can be raised
+# TODO: URL Encoding logic
+# TODO: Proxy support
+# TODO: Should direct method responses be a DirectMethodResponse object? If so, what is the rid?
+# See specific inline commentary for more details on what is required
+
+
+# NOTE: aiohttp 3.x is bugged on Windows on Python 3.8.x - 3.10.6
+# If running the application using asyncio.run(), there will be an issue with the Event Loop
+# raising a spurious RuntimeError on application exit.
+#
+# Windows Event Loops are notoriously tricky to deal with. This issue stems from the use of the
+# default ProactorEventLoop, and can be mitigated by switching to a SelectorEventLoop, but
+# we as SDK developers really ought not be modifying the end user's event loop, or monkeypatching
+# error suppression into it. Furthermore, switching to a SelectorEventLoop has some degradation of
+# functionality.
+#
+# The best course of action is for the end user to use loop.run_until_complete() instead of
+# asyncio.run() in their application, as this will allow for better cleanup.
+#
+# Eventually when there is an aiohttp 4.x released, this bug will be eliminated from all versions
+# of Python, but until then, there's not much to be done about it.
+#
+# See: https://github.com/aio-libs/aiohttp/issues/4324, as well as many, many other similar issues
+# for more details.
+
+
+class IoTHubHTTPClient:
+    def __init__(self, client_config: config.IoTHubClientConfig) -> None:
+        """Instantiate the client
+
+        :param client_config: The config object for the client
+        :type client_config: :class:`IoTHubClientConfig`
+        """
+        self._device_id = client_config.device_id
+        self._module_id = client_config.module_id
+        self._edge_module_id = _format_edge_module_id(self._device_id, self._module_id)
+        self._user_agent_string = user_agent.get_iothub_user_agent() + client_config.product_info
+
+        # TODO: add proxy support
+        # Doing so will require building a custom "Connector" that can be injected into the
+        # Session object. There are many examples around online.
+        # The built-in per-request proxy of aiohttp is only partly functional, so I decided to
+        # not even bother implementing it, if it only does half the job.
+        if client_config.proxy_options:
+            # TODO: these warnings should probably be at API level
+            logger.warning("Proxy use with .invoke_direct_method() not supported")
+            logger.warning("Proxy use with .get_storage_info_for_blob() not supported")
+            logger.warning("Proxy use with .notify_blob_upload_status() not supported")
+
+        self._session = _create_client_session(client_config.hostname)
+        self._ssl_context = client_config.ssl_context
+        self._sastoken_provider = client_config.sastoken_provider
+
+    async def shutdown(self):
+        """Shut down the client
+
+        Invoke only when completely finished with the client for graceful exit.
+        """
+        await asyncio.shield(self._session.close())
+        # Wait 250ms for the underlying SSL connections to close
+        # See: https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown
+        await asyncio.sleep(0.25)
+
+    async def invoke_direct_method(
+        self,
+        *,
+        device_id: str,
+        module_id: Optional[str] = None,
+        method_params: DirectMethodParameters
+    ) -> DirectMethodResult:
+        """Send a request to invoke a direct method on a target device or module
+
+        :param str device_id: The target device ID
+        :param str module_id: The target module ID
+        :param dict method_params: The parameters for the direct method invocation
+
+        :returns: A dictionary containing a status and payload reported by the target device
+        :rtype: dict
+
+        :raises: :class:`IoTHubClientError` if not using an IoT Edge Module
+        :raises: :class:`IoTHubClientError` if the direct method response cannot be parsed
+        :raises: :class:`IoTEdgeError` if IoT Edge responds with failure
+        """
+        if not self._edge_module_id:
+            # NOTE: The Edge Module ID will exist for any Module; it doesn't actually indicate
+            # if it is an Edge Module or not. There's no way to tell, unfortunately.
+            raise exc.IoTHubClientError(".invoke_direct_method() only available for Edge Modules")
+
+        path = http_path.get_direct_method_invoke_path(device_id, module_id)
+        query_params = {PARAM_API_VERISON: constant.IOTHUB_API_VERSION}
+        # NOTE: Other headers are auto-generated by aiohttp
+        headers = {
+            HEADER_USER_AGENT: urllib.parse.quote_plus(self._user_agent_string),
+            HEADER_EDGE_MODULE_ID: self._edge_module_id,  # TODO: I assume this isn't supposed to be URI encoded just like in MQTT?
+ } + # If using SAS auth, pass the auth header + if self._sastoken_provider: + headers[HEADER_AUTHORIZATION] = str(self._sastoken_provider.get_current_sastoken()) + + logger.debug( + "Sending direct method invocation request to {device_id}/{module_id}".format( + device_id=device_id, module_id=module_id + ) + ) + async with self._session.post( + url=path, + json=method_params, + params=query_params, + headers=headers, + ssl=self._ssl_context, + ) as response: + + if response.status >= 300: + logger.error("Received failure response from IoT Edge for direct method invocation") + raise exc.IoTEdgeError( + "IoT Edge responded to direct method invocation with a failed status ({status}) - {reason}".format( + status=response.status, reason=response.reason + ) + ) + else: + logger.debug( + "Successfully received response from IoT Edge for direct method invocation" + ) + dm_result = cast(DirectMethodResult, await response.json()) + + return dm_result + + async def get_storage_info_for_blob(self, *, blob_name: str) -> StorageInfo: + """Request information for uploading blob file via the Azure Storage SDK + + :param str blob_name: The name of the blob that will be uploaded to the Azure Storage SDK + + :returns: The Azure Storage information returned by IoTHub + :rtype: dict + + :raises: :class:`IoTHubClientError` if not using a Device + :raises: :class:`IoTHubError` if IoTHub responds with failure + """ + if self._module_id: + raise exc.IoTHubClientError(".get_storage_info_for_blob() only available for Devices") + + path = http_path.get_storage_info_for_blob_path( + self._device_id + ) # TODO: is this bad that this is encoding? 
aiohttp encodes automatically + query_params = {PARAM_API_VERISON: constant.IOTHUB_API_VERSION} + data = {"blobName": blob_name} + # NOTE: Other headers are auto-generated by aiohttp + headers = {HEADER_USER_AGENT: urllib.parse.quote_plus(self._user_agent_string)} + # If using SAS auth, pass the auth header + if self._sastoken_provider: + headers[HEADER_AUTHORIZATION] = str(self._sastoken_provider.get_current_sastoken()) + + logger.debug("Sending storage info request to IoTHub...") + async with self._session.post( + url=path, + json=data, + params=query_params, + headers=headers, + ssl=self._ssl_context, + ) as response: + + if response.status >= 300: + logger.error("Received failure response from IoTHub for storage info request") + raise exc.IoTHubError( + "IoTHub responded to storage info request with a failed status ({status}) - {reason}".format( + status=response.status, reason=response.reason + ) + ) + else: + logger.debug("Successfully received response from IoTHub for storage info request") + storage_info = cast(StorageInfo, await response.json()) + + return storage_info + + async def notify_blob_upload_status( + self, *, correlation_id: str, is_success: bool, status_code: int, status_description: str + ) -> None: + """Notify IoTHub of the result of a Azure Storage SDK blob upload + + :param str correlation_id: ID for the blob upload + :param bool is_success: Indicates whether the file was uploaded successfully + :param int status_code: A numeric status code for the file upload + :param str status_description: A description that corresponds to the status_code + + :raises: :class:`IoTHubClientError` if not using a Device + :raises: :class:`IoTHubError` if IoTHub responds with failure + """ + if self._module_id: + raise exc.IoTHubClientError(".notify_blob_upload_status() only available for Devices") + + path = http_path.get_notify_blob_upload_status_path(self._device_id) + query_params = {PARAM_API_VERISON: constant.IOTHUB_API_VERSION} + data = { + 
"correlationId": correlation_id, + "isSuccess": is_success, + "statusCode": status_code, + "statusDescription": status_description, + } + # NOTE: Other headers are auto-generated by aiohttp + headers = {HEADER_USER_AGENT: urllib.parse.quote_plus(self._user_agent_string)} + # If using SAS auth, pass the auth header + if self._sastoken_provider: + headers[HEADER_AUTHORIZATION] = str(self._sastoken_provider.get_current_sastoken()) + + logger.debug("Sending blob upload notification to IoTHub...") + async with self._session.post( + url=path, + json=data, + params=query_params, + headers=headers, + ssl=self._ssl_context, + ) as response: + + if response.status >= 300: + logger.error("Received failure response from IoTHub for blob upload notification") + raise exc.IoTHubError( + "IoTHub responded to blob upload notification with a failed status ({status}) - {reason}".format( + status=response.status, reason=response.reason + ) + ) + else: + logger.debug( + "Successfully received from response from IoTHub for blob upload notification" + ) + + return None + + +def _format_edge_module_id(device_id: str, module_id: Optional[str]) -> Optional[str]: + """Returns the edge module identifier""" + if module_id: + return "{device_id}/{module_id}".format(device_id=device_id, module_id=module_id) + else: + return None + + +def _create_client_session(hostname: str) -> aiohttp.ClientSession: + """Create and return a aiohttp ClientSession object""" + base_url = "https://{hostname}".format(hostname=hostname) + timeout = aiohttp.ClientTimeout(total=HTTP_TIMEOUT) + session = aiohttp.ClientSession(base_url=base_url, timeout=timeout) + logger.debug( + "Creating HTTP Session for {url} with timeout of {timeout}".format( + url=base_url, timeout=timeout.total + ) + ) + return session diff --git a/azure-iot-device/azure/iot/device/iothub_mqtt_client.py b/azure-iot-device/azure/iot/device/iothub_mqtt_client.py new file mode 100644 index 000000000..d59ef707d --- /dev/null +++ 
b/azure-iot-device/azure/iot/device/iothub_mqtt_client.py @@ -0,0 +1,676 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import asyncio +import json +import logging +import urllib.parse +from typing import Callable, Optional, AsyncGenerator, TypeVar +from .custom_typing import TwinPatch, Twin +from . import config, constant, user_agent, models +from . import exceptions as exc +from . import request_response as rr +from . import mqtt_client as mqtt +from . import mqtt_topic_iothub as mqtt_topic + +# TODO: update docstrings with correct class paths once repo structured better +# TODO: If we're truly done with keeping SAS credentials fresh, we don't need to use SasTokenProvider, +# and we could just simply use a single token or generator instead. 
+ +logger = logging.getLogger(__name__) + +DEFAULT_RECONNECT_INTERVAL: int = 10 + +_T = TypeVar("_T") + + +class IoTHubMQTTClient: + def __init__( + self, + client_config: config.IoTHubClientConfig, + ) -> None: + """Instantiate the client + + :param client_config: The config object for the client + :type client_config: :class:`IoTHubClientConfig` + """ + # Identity + self._device_id = client_config.device_id + self._module_id = client_config.module_id + self._client_id = _format_client_id(self._device_id, self._module_id) + self._username = _format_username( + hostname=client_config.hostname, + client_id=self._client_id, + product_info=client_config.product_info, + ) + + # SAS (Optional) + self._sastoken_provider = client_config.sastoken_provider + + # MQTT Configuration + self._mqtt_client = _create_mqtt_client(self._client_id, client_config) + # NOTE: credentials are set upon `.start()` + + # Add filters for receive topics delivering data used internally + twin_response_topic = mqtt_topic.get_twin_response_topic_for_subscribe() + self._mqtt_client.add_incoming_message_filter(twin_response_topic) + + # Create generators for receive topics delivering data used externally + # (Implicitly adding filters for these topics as well) + self._incoming_input_messages: Optional[AsyncGenerator[models.Message, None]] = None + self._incoming_c2d_messages: Optional[AsyncGenerator[models.Message, None]] = None + self._incoming_direct_method_requests: AsyncGenerator[models.DirectMethodRequest, None] + self._incoming_twin_patches: AsyncGenerator[TwinPatch, None] + if self._module_id: + self._incoming_input_messages = self._create_incoming_data_generator( + topic=mqtt_topic.get_input_topic_for_subscribe(self._device_id, self._module_id), + transform_fn=_create_iothub_message_from_mqtt_message, + ) + else: + self._incoming_c2d_messages = self._create_incoming_data_generator( + topic=mqtt_topic.get_c2d_topic_for_subscribe(self._device_id), + 
transform_fn=_create_iothub_message_from_mqtt_message, + ) + self._incoming_direct_method_requests = self._create_incoming_data_generator( + topic=mqtt_topic.get_direct_method_request_topic_for_subscribe(), + transform_fn=_create_direct_method_request_from_mqtt_message, + ) + self._incoming_twin_patches = self._create_incoming_data_generator( + topic=mqtt_topic.get_twin_patch_topic_for_subscribe(), + transform_fn=_create_twin_patch_from_mqtt_message, + ) + + # Internal request/response infrastructure + self._request_ledger = rr.RequestLedger() + self._twin_responses_enabled = False + + # Background Tasks (Will be set upon `.start()`) + self._process_twin_responses_bg_task: Optional[asyncio.Task[None]] = None + + def _create_incoming_data_generator( + self, topic: str, transform_fn: Callable[[mqtt.MQTTMessage], _T] + ) -> AsyncGenerator[_T, None]: + """Return a generator for incoming MQTT data on a given topic, yielding a transformation + of that data via the given transform function""" + self._mqtt_client.add_incoming_message_filter(topic) + incoming_mqtt_messages = self._mqtt_client.get_incoming_message_generator(topic) + + async def generator() -> AsyncGenerator[_T, None]: + async for mqtt_message in incoming_mqtt_messages: + try: + yield transform_fn(mqtt_message) + mqtt_message = None + except asyncio.CancelledError: + # NOTE: In Python 3.7 this isn't a BaseException, so we must catch and re-raise + # NOTE: This shouldn't ever happen since none of the transform_fns should be + # doing async invocations, but can't hurt to have this for future-proofing. + raise + except Exception as e: + # TODO: background exception logging improvements (e.g. 
stacktrace) + logger.error("Failure transforming MQTTMessage: {}".format(e)) + logger.warning("Dropping MQTTMessage that could not be transformed") + + return generator() + + async def _enable_twin_responses(self) -> None: + """Enable receiving of twin responses (for twin requests, or twin patches) from IoTHub""" + logger.debug("Enabling receive of twin responses...") + topic = mqtt_topic.get_twin_response_topic_for_subscribe() + await self._mqtt_client.subscribe(topic) + self._twin_responses_enabled = True + logger.debug("Twin responses receive enabled") + + async def _process_twin_responses(self) -> None: + """Run indefinitely, matching twin responses with request ID""" + logger.debug("Starting the 'process_twin_responses' background task") + twin_response_topic = mqtt_topic.get_twin_response_topic_for_subscribe() + twin_responses = self._mqtt_client.get_incoming_message_generator(twin_response_topic) + + async for mqtt_message in twin_responses: + try: + request_id = mqtt_topic.extract_request_id_from_twin_response_topic( + mqtt_message.topic + ) + status_code = int( + mqtt_topic.extract_status_code_from_twin_response_topic(mqtt_message.topic) + ) + # NOTE: We don't know what the content of the body is until we match the rid, so don't + # do more than just decode it here - leave interpreting the string to the coroutine + # waiting for the response. + response_body = mqtt_message.payload.decode("utf-8") + logger.debug("Twin response received (rid: {})".format(request_id)) + response = rr.Response( + request_id=request_id, status=status_code, body=response_body + ) + except Exception as e: + logger.error( + "Unexpected error ({}) while translating Twin response. Dropping.".format(e) + ) + # NOTE: In this situation the operation waiting for the response that we failed to + # receive will hang. 
This isn't the end of the world, since it can be cancelled, + # but if we really wanted to smooth this out, we could cancel the pending operation + # based on the request id (assuming getting the request id is not what failed). + # But for now, that's probably overkill, especially since this path ideally should + # never happen, because we would like to assume IoTHub isn't sending malformed data + continue + try: + await self._request_ledger.match_response(response) + except asyncio.CancelledError: + # NOTE: In Python 3.7 this isn't a BaseException, so we must catch and re-raise + raise + except KeyError: + # NOTE: This should only happen in edge cases involving cancellation of + # in-flight operations + logger.warning( + "Twin response (rid: {}) does not match any request".format(request_id) + ) + except Exception as e: + logger.error( + "Unexpected error ({}) while matching Twin response (rid: {}). Dropping response".format( + e, request_id + ) + ) + + async def start(self) -> None: + """Start up the client. + + - Must be invoked before any other methods. + - If already started, will not (meaningfully) do anything. + """ + # Set credentials + if self._sastoken_provider: + logger.debug("Using SASToken as password") + password = str(self._sastoken_provider.get_current_sastoken()) + else: + logger.debug("No password used") + password = None + self._mqtt_client.set_credentials(self._username, password) + # Start background tasks + if not self._process_twin_responses_bg_task: + self._process_twin_responses_bg_task = asyncio.create_task( + self._process_twin_responses() + ) + + async def stop(self) -> None: + """Stop the client. + + - Must be invoked when done with the client for graceful exit. + - If already stopped, will not do anything. + - Cannot be cancelled - if you try, the client will still fully shut down as much as + possible, although CancelledError will still be raised. 
+ """ + cancelled_tasks = [] + logger.debug("Stopping IoTHubMQTTClient...") + + if self._process_twin_responses_bg_task: + logger.debug("Cancelling 'process_twin_responses' background task") + self._process_twin_responses_bg_task.cancel() + cancelled_tasks.append(self._process_twin_responses_bg_task) + self._process_twin_responses_bg_task = None + + results = await asyncio.gather( + *cancelled_tasks, asyncio.shield(self.disconnect()), return_exceptions=True + ) + for result in results: + # NOTE: Need to specifically exclude asyncio.CancelledError because it is not a + # BaseException in Python 3.7 + if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError): + raise result + + async def connect(self) -> None: + """Connect to IoTHub + + :raises: MQTTConnectionFailedError if there is a failure connecting + """ + # Connect + logger.debug("Connecting to IoTHub...") + await self._mqtt_client.connect() + logger.debug("Connect succeeded") + + async def disconnect(self) -> None: + """Disconnect from IoTHub""" + logger.debug("Disconnecting from IoTHub...") + await self._mqtt_client.disconnect() + logger.debug("Disconnect succeeded") + + async def wait_for_disconnect(self) -> Optional[exc.MQTTConnectionDroppedError]: + """Block until disconnection and return the cause, if any + + :returns: An MQTTConnectionDroppedError if the connection was dropped, or None if the + connection was intentionally ended + :rtype: MQTTConnectionDroppedError or None + """ + async with self._mqtt_client.disconnected_cond: + await self._mqtt_client.disconnected_cond.wait_for(lambda: not self.connected) + return self._mqtt_client.previous_disconnection_cause() + + async def send_message(self, message: models.Message) -> None: + """Send a telemetry message to IoTHub. 
+ + :param message: The Message to be sent + :type message: :class:`models.Message` + + :raises: MQTTError if there is an error sending the Message + :raises: ValueError if the size of the Message payload is too large + """ + # Format topic with message properties + telemetry_topic = mqtt_topic.get_telemetry_topic_for_publish( + self._device_id, self._module_id + ) + topic = mqtt_topic.insert_message_properties_in_topic( + topic=telemetry_topic, + system_properties=message.get_system_properties_dict(), + custom_properties=message.custom_properties, + ) + # Format payload based on content configuration + if message.content_type == "application/json": + str_payload = json.dumps(message.payload) + else: + str_payload = str(message.payload) + byte_payload = str_payload.encode(message.content_encoding) + # Send + logger.debug("Sending telemetry message to IoTHub...") + await self._mqtt_client.publish(topic, byte_payload) + logger.debug("Sending telemetry message succeeded") + + async def send_direct_method_response( + self, method_response: models.DirectMethodResponse + ) -> None: + """Send a direct method response to IoTHub. + + :param method_response: The DirectMethodResponse to be sent + :type method_response: :class:`models.DirectMethodResponse` + + :raises: MQTTError if there is an error sending the DirectMethodResponse + :raises: ValueError if the size of the DirectMethodResponse payload is too large + """ + topic = mqtt_topic.get_direct_method_response_topic_for_publish( + method_response.request_id, method_response.status + ) + payload = json.dumps(method_response.payload) + logger.debug( + "Sending direct method response to IoTHub... 
(rid: {})".format( + method_response.request_id + ) + ) + await self._mqtt_client.publish(topic, payload) + logger.debug( + "Sending direct method response succeeded (rid: {})".format(method_response.request_id) + ) + + async def send_twin_patch(self, patch: TwinPatch) -> None: + """Send a twin patch to IoTHub + + :param patch: The JSON patch to send + :type patch: dict, list, tuple, str, int, float, bool, None + + :raises: IoTHubError if an error response is received from IoT Hub + :raises: MQTTError if there is an error sending the twin patch + :raises: ValueError if the size of the the twin patch is too large + :raises: CancelledError if enabling twin responses is cancelled by network failure + """ + if not self._twin_responses_enabled: + await self._enable_twin_responses() + + request = await self._request_ledger.create_request() + try: + topic = mqtt_topic.get_twin_patch_topic_for_publish(request.request_id) + + # Send the patch to IoTHub + try: + logger.debug("Sending twin patch to IoTHub... (rid: {})".format(request.request_id)) + await self._mqtt_client.publish(topic, json.dumps(patch)) + except asyncio.CancelledError: + logger.warning( + "Attempt to send twin patch to IoTHub was cancelled while in flight. It may or may not have been received (rid: {})".format( + request.request_id + ) + ) + raise + except Exception: + logger.error( + "Sending twin patch to IoTHub failed (rid: {})".format(request.request_id) + ) + raise + + # Wait for a response from IoTHub + try: + logger.debug( + "Waiting for response to the twin patch from IoTHub... (rid: {})".format( + request.request_id + ) + ) + response = await request.get_response() + except asyncio.CancelledError: + logger.debug( + "Attempt to send twin patch to IoTHub was cancelled while waiting for response. 
If the response arrives, it will be discarded (rid: {})".format( + request.request_id + ) + ) + raise + + # Interpret response + logger.debug( + "Received twin patch response with status {} (rid: {})".format( + response.status, request.request_id + ) + ) + # TODO: should body be logged? Is there useful info there? + if response.status >= 300: + raise exc.IoTHubError( + "IoTHub responded to twin patch with a failed status - {}".format( + response.status + ) + ) + finally: + # If an exception caused exit before a pending request could be matched with a response + # then manually delete to prevent leaks. + if request.request_id in self._request_ledger: + await self._request_ledger.delete_request(request.request_id) + + async def get_twin(self) -> Twin: + """Request a full twin from IoTHub + + :returns: The full twin as a JSON object + :rtype: dict + + :raises: IoTHubError if an error response is received from IoT Hub + :raises: MQTTError if there is an error sending the twin request + :raises: CancelledError if enabling twin responses is cancelled by network failure + """ + if not self._twin_responses_enabled: + await self._enable_twin_responses() + + request = await self._request_ledger.create_request() + try: + topic = mqtt_topic.get_twin_request_topic_for_publish(request_id=request.request_id) + + # Send the twin request to IoTHub + try: + logger.debug( + "Sending get twin request to IoTHub... (rid: {})".format(request.request_id) + ) + await self._mqtt_client.publish(topic, " ") + except asyncio.CancelledError: + logger.warning( + "Attempt to send get twin request to IoTHub was cancelled while in flight. It may or may not have been received (rid: {})".format( + request.request_id + ) + ) + raise + except Exception: + logger.error( + "Sending get twin request to IoTHub failed (rid: {})".format(request.request_id) + ) + raise + + # Wait for a response from IoTHub + try: + logger.debug( + "Waiting to receive twin from IoTHub... 
(rid: {})".format(request.request_id) + ) + response = await request.get_response() + except asyncio.CancelledError: + logger.debug( + "Attempt to get twin from IoTHub was cancelled while waiting for a response. If the response arrives, it will be discarded (rid: {})".format( + request.request_id + ) + ) + raise + finally: + # If an exception caused exit before a pending request could be matched with a response + # then manually delete to prevent leaks. + if request.request_id in self._request_ledger: + await self._request_ledger.delete_request(request.request_id) + + # Interpret response + if response.status >= 300: + raise exc.IoTHubError( + "IoTHub responded to get twin request with a failed status - {}".format( + response.status + ) + ) + else: + logger.debug("Received twin from IoTHub (rid: {})".format(request.request_id)) + twin: Twin = json.loads(response.body) + return twin + + async def enable_c2d_message_receive(self) -> None: + """Enable the ability to receive C2D messages + + :raises: MQTTError if there is an error enabling C2D message receive + :raises: CancelledError if enabling C2D message receive is cancelled by network failure + :raises: IoTHubClientError if client not configured for a Device + """ + if self._module_id: + raise exc.IoTHubClientError("C2D messages not available on Modules") + logger.debug("Enabling receive for C2D messages...") + topic = mqtt_topic.get_c2d_topic_for_subscribe(self._device_id) + await self._mqtt_client.subscribe(topic) + logger.debug("C2D message receive enabled") + + async def disable_c2d_message_receive(self) -> None: + """Disable the ability to receive C2D messages + + :raises: MQTTError if there is an error disabling C2D message receive + :raises: CancelledError if disabling C2D message receive is cancelled by network failure + :raises: IoTHubClientError if client not configured for a Device + """ + if self._module_id: + raise exc.IoTHubClientError("C2D messages not available on Modules") + 
logger.debug("Disabling receive for C2D messages...") + topic = mqtt_topic.get_c2d_topic_for_subscribe(self._device_id) + await self._mqtt_client.unsubscribe(topic) + logger.debug("C2D message receive disabled") + + async def enable_input_message_receive(self) -> None: + """Enable the ability to receive input messages + + :raises: MQTTError if there is an error enabling input message receive + :raises: CancelledError if enabling input message receive is cancelled by network failure + :raises: IoTHubClientError if client not configured for a Module + """ + if not self._module_id: + raise exc.IoTHubClientError("Input messages not available on Devices") + logger.debug("Enabling receive for input messages...") + topic = mqtt_topic.get_input_topic_for_subscribe(self._device_id, self._module_id) + await self._mqtt_client.subscribe(topic) + logger.debug("Input message receive enabled") + + async def disable_input_message_receive(self) -> None: + """Disable the ability to receive input messages + + :raises: MQTTError if there is an error disabling input message receive + :raises: CancelledError if disabling input message receive is cancelled by network failure + :raises: IoTHubClientError if client not configured for a Module + """ + if not self._module_id: + raise exc.IoTHubClientError("Input messages not available on Devices") + logger.debug("Disabling receive for input messages...") + topic = mqtt_topic.get_input_topic_for_subscribe(self._device_id, self._module_id) + await self._mqtt_client.unsubscribe(topic) + logger.debug("Input message receive disabled") + + async def enable_direct_method_request_receive(self) -> None: + """Enable the ability to receive direct method requests + + :raises: MQTTError if there is an error enabling direct method request receive + :raises: CancelledError if enabling direct method request receive is cancelled by + network failure + """ + logger.debug("Enabling receive for direct method requests...") + topic = 
mqtt_topic.get_direct_method_request_topic_for_subscribe() + await self._mqtt_client.subscribe(topic) + logger.debug("Direct method request receive enabled") + + async def disable_direct_method_request_receive(self) -> None: + """Disable the ability to receive direct method requests + + :raises: MQTTError if there is an error disabling direct method request receive + :raises: CancelledError if disabling direct method request receive is cancelled by + network failure + """ + logger.debug("Disabling receive for direct method requests...") + topic = mqtt_topic.get_direct_method_request_topic_for_subscribe() + await self._mqtt_client.unsubscribe(topic) + logger.debug("Direct method request receive disabled") + + async def enable_twin_patch_receive(self) -> None: + """Enable the ability to receive twin patches + + :raises: MQTTError if there is an error enabling twin patch receive + :raises: CancelledError if enabling twin patch receive is cancelled by network failure + """ + logger.debug("Enabling receive for twin patches...") + topic = mqtt_topic.get_twin_patch_topic_for_subscribe() + await self._mqtt_client.subscribe(topic) + logger.debug("Twin patch receive enabled") + + async def disable_twin_patch_receive(self) -> None: + """Disable the ability to receive twin patches + + :raises: MQTTError if there is an error disabling twin patch receive + :raises: CancelledError if disabling twin patch receive is cancelled by network failure + """ + logger.debug("Disabling receive for twin patches...") + topic = mqtt_topic.get_twin_patch_topic_for_subscribe() + await self._mqtt_client.unsubscribe(topic) + logger.debug("Twin patch receive disabled") + + @property + def incoming_c2d_messages(self) -> AsyncGenerator[models.Message, None]: + """Generator that yields incoming C2D Messages""" + if not self._incoming_c2d_messages: + raise exc.IoTHubClientError("C2D Messages not available for Module") + else: + return self._incoming_c2d_messages + + @property + def 
incoming_input_messages(self) -> AsyncGenerator[models.Message, None]: + """Generator that yields incoming input Messages""" + if not self._incoming_input_messages: + raise exc.IoTHubClientError("Input Messages not available for Device") + else: + return self._incoming_input_messages + + @property + def incoming_direct_method_requests( + self, + ) -> AsyncGenerator[models.DirectMethodRequest, None]: + """Generator that yields incoming DirectMethodRequests""" + return self._incoming_direct_method_requests + + @property + def incoming_twin_patches(self) -> AsyncGenerator[TwinPatch, None]: + """Generator that yields incoming TwinPatches""" + return self._incoming_twin_patches + + @property + def connected(self) -> bool: + """Boolean indicating connection status""" + return self._mqtt_client.is_connected() + + +def _format_client_id(device_id: str, module_id: Optional[str] = None) -> str: + if module_id: + client_id = "{}/{}".format(device_id, module_id) + else: + client_id = device_id + return client_id + + +def _create_mqtt_client( + client_id: str, client_config: config.IoTHubClientConfig +) -> mqtt.MQTTClient: + logger.debug("Creating MQTTClient") + + logger.debug("Using {} as hostname".format(client_config.hostname)) + + if client_config.module_id: + logger.debug("Using IoTHub Module. Client ID is {}".format(client_id)) + else: + logger.debug("Using IoTHub Device. 
Client ID is {}".format(client_id)) + + if client_config.websockets: + logger.debug("Using MQTT over websockets") + transport = "websockets" + port = 443 + websockets_path = "/$iothub/websocket" + else: + logger.debug("Using MQTT over TCP") + transport = "tcp" + port = 8883 + websockets_path = None + + client = mqtt.MQTTClient( + client_id=client_id, + hostname=client_config.hostname, + port=port, + transport=transport, + keep_alive=client_config.keep_alive, + auto_reconnect=client_config.auto_reconnect, + reconnect_interval=DEFAULT_RECONNECT_INTERVAL, + ssl_context=client_config.ssl_context, + websockets_path=websockets_path, + proxy_options=client_config.proxy_options, + ) + + return client + + +def _format_username(hostname: str, client_id: str, product_info: str) -> str: + query_param_seq = [] + + # Apply query parameters (i.e. key1=value1&key2=value2...&keyN=valueN format) + if product_info.startswith(constant.DIGITAL_TWIN_PREFIX): # Digital Twin Stuff + query_param_seq.append(("api-version", constant.IOTHUB_API_VERSION)) + query_param_seq.append(("DeviceClientType", user_agent.get_iothub_user_agent())) + query_param_seq.append((constant.DIGITAL_TWIN_QUERY_HEADER, product_info)) + else: + query_param_seq.append(("api-version", constant.IOTHUB_API_VERSION)) + query_param_seq.append( + ("DeviceClientType", user_agent.get_iothub_user_agent() + product_info) + ) + + # NOTE: Client ID (including the device and/or module ids that are in it) + # is NOT url encoded as part of the username. Neither is the hostname. + # The sequence of key/value property pairs (query_param_seq) however, MUST have all + # keys and values URL encoded. 
+ # See the repo wiki article for details: + # https://github.com/Azure/azure-iot-sdk-python/wiki/URL-Encoding-(MQTT) + username = "{hostname}/{client_id}/?{query_params}".format( + hostname=hostname, + client_id=client_id, + query_params=urllib.parse.urlencode(query_param_seq, quote_via=urllib.parse.quote), + ) + return username + + +def _create_iothub_message_from_mqtt_message(mqtt_message: mqtt.MQTTMessage) -> models.Message: + """Given an MQTTMessage, create and return a Message""" + properties = mqtt_topic.extract_properties_from_message_topic(mqtt_message.topic) + # Decode the payload based on content encoding in the topic. If not present, use utf-8 + content_encoding = properties.get("$.ce", "utf-8") + content_type = properties.get("$.ct", "text/plain") + payload = mqtt_message.payload.decode(content_encoding) + if content_type == "application/json": + payload = json.loads(payload) + return models.Message.create_from_properties_dict(payload=payload, properties=properties) + + +def _create_direct_method_request_from_mqtt_message( + mqtt_message: mqtt.MQTTMessage, +) -> models.DirectMethodRequest: + """Given an MQTTMessage, create and return a DirectMethodRequest""" + request_id = mqtt_topic.extract_request_id_from_direct_method_request_topic(mqtt_message.topic) + method_name = mqtt_topic.extract_name_from_direct_method_request_topic(mqtt_message.topic) + payload = json.loads(mqtt_message.payload.decode("utf-8")) + return models.DirectMethodRequest(request_id=request_id, name=method_name, payload=payload) + + +def _create_twin_patch_from_mqtt_message(mqtt_message: mqtt.MQTTMessage) -> TwinPatch: + """Given an MQTTMessage, create and return a TwinPatch""" + return json.loads(mqtt_message.payload.decode("utf-8")) diff --git a/azure-iot-device/azure/iot/device/iothub_session.py b/azure-iot-device/azure/iot/device/iothub_session.py new file mode 100644 index 000000000..4f221b65a --- /dev/null +++ b/azure-iot-device/azure/iot/device/iothub_session.py 
@@ -0,0 +1,453 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import asyncio +import contextlib +import functools +import ssl +from typing import Optional, Union, AsyncGenerator, Type, TypeVar, Awaitable +from types import TracebackType + +from . import exceptions as exc +from . import signing_mechanism as sm +from . import connection_string as cs +from . import sastoken as st +from . import config, models, custom_typing +from . import iothub_mqtt_client as mqtt + +_T = TypeVar("_T") + +# TODO: add tests for sastoken_ttl argument once we settle on a SAS strategy + + +def _requires_connection(f): + """Decorator to indicate a method requires the Session to already be connected.""" + + @functools.wraps(f) + def check_connection_wrapper(*args, **kwargs): + this = args[0] # a.k.a. self + if not this._mqtt_client.connected: + # NOTE: We need to raise an error directly if not connected because at MQTT + # Quality of Service (QoS) level 1, used at the lower levels of this stack, + # a MQTT Publish does not actually fail if not connected - instead, it waits + # for a connection to be established, and publishes the data once connected. + # + # This is not desirable behavior, so we check the connection state before + # any network operation over MQTT. While this issue only affects MQTT Publishes, + # and not MQTT Subscribes or Unsubscribes, we want this logic to be used + # on all methods that do MQTT operations for consistency. + raise exc.SessionError("IoTHubSession not connected") + else: + return f(*args, **kwargs) + + return check_connection_wrapper + + +class IoTHubSession: + def __init__( + self, + *, + hostname: str, # iothub_hostname? 
+ device_id: str, + module_id: Optional[str] = None, + ssl_context: Optional[ssl.SSLContext] = None, + shared_access_key: Optional[str] = None, + sastoken_fn: Optional[custom_typing.FunctionOrCoroutine] = None, + sastoken_ttl: int = 3600, + **kwargs, + ) -> None: + """ + :param str device_id: The device identity for the IoT Hub device containing the + IoT Hub module + :param str module_id: The module identity for the IoT Hub module + :param str hostname: Hostname of the IoT Hub or IoT Edge the device should connect to + :param ssl_context: Custom SSL context to be used when establishing a connection. + If not provided, a default one will be used + :type ssl_context: :class:`ssl.SSLContext` + :param str shared_access_key: A key that can be used to generate SAS Tokens + :param sastoken_fn: A function or coroutine function that takes no arguments and returns + a SAS token string when invoked + :param sastoken_ttl: Time-to-live (in seconds) for SAS tokens generated when using + 'shared_access_key' authentication. + If using this auth type, a new Session will need to be created once this time expires. + Default is 3600 seconds (1 hour). + + :keyword int keep_alive: Maximum period in seconds between MQTT communications. If no + communications are exchanged for this period, a ping exchange will occur. + Default is 60 seconds + :keyword str product_info: Arbitrary product information which will be included in the + User-Agent string + :keyword proxy_options: Configuration structure for sending traffic through a proxy server + :type: proxy_options: :class:`ProxyOptions` + :keyword bool websockets: Set to 'True' to use WebSockets over MQTT. 
Default is 'False' + + :raises: ValueError if an invalid combination of parameters is provided + :raises: ValueError if an invalid 'shared_access_key' is provided + :raises: TypeError if an invalid keyword argument is provided + """ + # Validate parameters + _validate_kwargs(**kwargs) + if shared_access_key and sastoken_fn: + raise ValueError( + "Incompatible authentication - cannot provide both 'shared_access_key' and 'sastoken_fn'" + ) + if not shared_access_key and not sastoken_fn and not ssl_context: + raise ValueError( + "Missing authentication - must provide one of 'shared_access_key', 'sastoken_fn' or 'ssl_context'" + ) + + # Set up SAS auth (if using) + generator: Optional[st.SasTokenGenerator] + # NOTE: Need to keep a reference to the SasTokenProvider so we can stop it during cleanup + self._sastoken_provider: Optional[st.SasTokenProvider] + if shared_access_key: + uri = _format_sas_uri(hostname=hostname, device_id=device_id, module_id=module_id) + signing_mechanism = sm.SymmetricKeySigningMechanism(shared_access_key) + generator = st.InternalSasTokenGenerator( + signing_mechanism=signing_mechanism, uri=uri, ttl=sastoken_ttl + ) + self._sastoken_provider = st.SasTokenProvider(generator) + elif sastoken_fn: + generator = st.ExternalSasTokenGenerator(sastoken_fn) + self._sastoken_provider = st.SasTokenProvider(generator) + else: + self._sastoken_provider = None + + # Create a default SSLContext if not provided + if not ssl_context: + ssl_context = _default_ssl_context() + + # Instantiate the MQTTClient + client_config = config.IoTHubClientConfig( + hostname=hostname, + device_id=device_id, + module_id=module_id, + sastoken_provider=self._sastoken_provider, + ssl_context=ssl_context, + auto_reconnect=False, # We do not reconnect in a Session + **kwargs, + ) + self._mqtt_client = mqtt.IoTHubMQTTClient(client_config) + + # This task is used to propagate dropped connections through receiver generators + # It will be set upon context manager entry and cleared upon
exit + # NOTE: If we wanted to design lower levels of the stack to be specific to our + # Session design pattern, this could happen lower (and it would be simpler), but it's + # up here so we can be more implementation-generic down the stack. + self._wait_for_disconnect_task: Optional[ + asyncio.Task[Optional[exc.MQTTConnectionDroppedError]] + ] = None + + async def __aenter__(self) -> "IoTHubSession": + # First, if using SAS auth, start up the provider + if self._sastoken_provider: + # NOTE: No try/except block is needed here because in the case of failure there is not + # yet anything that we would need to clean up. + await self._sastoken_provider.start() + + # Start/connect + try: + await self._mqtt_client.start() + await self._mqtt_client.connect() + except (Exception, asyncio.CancelledError): + # Stop/cleanup if something goes wrong + await self._stop_all() + raise + + self._wait_for_disconnect_task = asyncio.create_task( + self._mqtt_client.wait_for_disconnect() + ) + + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + traceback: TracebackType, + ) -> None: + try: + await self._mqtt_client.disconnect() + finally: + # TODO: is it dangerous to cancel / remove this task? 
+ if self._wait_for_disconnect_task: + self._wait_for_disconnect_task.cancel() + self._wait_for_disconnect_task = None + await self._stop_all() + + async def _stop_all(self) -> None: + try: + await self._mqtt_client.stop() + finally: + if self._sastoken_provider: + await self._sastoken_provider.stop() + + @classmethod + def from_connection_string( + cls, + connection_string: str, + ssl_context: Optional[ssl.SSLContext] = None, + sastoken_ttl: int = 3600, + **kwargs, + ) -> "IoTHubSession": + """Instantiate an IoTHubSession using an IoT Hub device or module connection string + + :returns: A new instance of IoTHubSession + :rtype: IoTHubSession + + :param str connection_string: The IoT Hub device connection string + :param ssl_context: Custom SSL context to be used when establishing a connection. + If not provided, a default one will be used + :type ssl_context: :class:`ssl.SSLContext` + :param sastoken_ttl: Time-to-live (in seconds) for SAS tokens used for authentication. + A new Session will need to be created once this time expires. + Default is 3600 seconds (1 hour). + + :keyword int keep_alive: Maximum period in seconds between MQTT communications. If no + communications are exchanged for this period, a ping exchange will occur. + Default is 60 seconds + :keyword str product_info: Arbitrary product information which will be included in the + User-Agent string + :keyword proxy_options: Configuration structure for sending traffic through a proxy server + :type: proxy_options: :class:`ProxyOptions` + :keyword bool websockets: Set to 'True' to use WebSockets over MQTT. 
Default is 'False' + + :raises: ValueError if the provided connection string is invalid + :raises: TypeError if an invalid keyword argument is provided + """ + cs_obj = cs.ConnectionString(connection_string) + if cs_obj.get(cs.X509, "").lower() == "true" and ssl_context is None: + raise ValueError( + "Connection string indicates X509 certificate authentication, but no ssl_context provided" + ) + if cs.GATEWAY_HOST_NAME in cs_obj: + hostname = cs_obj[cs.GATEWAY_HOST_NAME] + else: + hostname = cs_obj[cs.HOST_NAME] + return cls( + hostname=hostname, + device_id=cs_obj[cs.DEVICE_ID], + module_id=cs_obj.get(cs.MODULE_ID), + shared_access_key=cs_obj.get(cs.SHARED_ACCESS_KEY), + ssl_context=ssl_context, + sastoken_ttl=sastoken_ttl, + **kwargs, + ) + + @_requires_connection + async def send_message(self, message: Union[str, models.Message]) -> None: + """Send a telemetry message to IoT Hub + + :param message: Message to send. If not a Message object, will be used as the payload of + a new Message object. 
+ :type message: str or :class:`Message` + + :raises: MQTTError if there is an error sending the Message + :raises: ValueError if the size of the Message payload is too large + :raises: SessionError if not connected when invoked + """ + if not isinstance(message, models.Message): + message = models.Message(message) + await self._add_disconnect_interrupt_to_coroutine(self._mqtt_client.send_message(message)) + + @_requires_connection + async def send_direct_method_response( + self, method_response: models.DirectMethodResponse + ) -> None: + """Send a response to a direct method request + + :param method_response: The response object containing information regarding the result of + the direct method invocation + :type method_response: :class:`DirectMethodResponse` + + :raises: MQTTError if there is an error sending the DirectMethodResponse + :raises: ValueError if the size of the DirectMethodResponse payload is too large + """ + await self._add_disconnect_interrupt_to_coroutine( + self._mqtt_client.send_direct_method_response(method_response) + ) + + @_requires_connection + async def update_reported_properties(self, patch: custom_typing.TwinPatch) -> None: + """Update the reported properties of the Twin + + :param dict patch: JSON object containing the updates to the Twin reported properties + + :raises: IoTHubError if an error response is received from IoT Hub + :raises: MQTTError if there is an error sending the updated reported properties + :raises: ValueError if the size of the reported properties patch is too large + :raises: CancelledError if enabling responses from IoT Hub is cancelled by network failure + """ + await self._add_disconnect_interrupt_to_coroutine(self._mqtt_client.send_twin_patch(patch)) + + @_requires_connection + async def get_twin(self) -> custom_typing.Twin: + """Retrieve the full Twin data + + :returns: Twin as a JSON object + :rtype: dict + + :raises: IoTHubError if an error response is received from IoTHub + :raises: MQTTError if there is
an error sending the request + :raises: CancelledError if enabling responses from IoT Hub is cancelled by network failure + """ + return await self._add_disconnect_interrupt_to_coroutine(self._mqtt_client.get_twin()) + + @contextlib.asynccontextmanager + @_requires_connection + async def messages(self) -> AsyncGenerator[AsyncGenerator[models.Message, None], None]: + """Returns an async generator of incoming C2D messages""" + await self._mqtt_client.enable_c2d_message_receive() + try: + yield self._add_disconnect_interrupt_to_generator( + self._mqtt_client.incoming_c2d_messages + ) + finally: + try: + if self._mqtt_client.connected: + await self._mqtt_client.disable_c2d_message_receive() + except exc.MQTTError: + # i.e. not connected + # This error would be expected if a disconnection has occurred + pass + + @contextlib.asynccontextmanager + @_requires_connection + async def direct_method_requests( + self, + ) -> AsyncGenerator[AsyncGenerator[models.DirectMethodRequest, None], None]: + """Returns an async generator of incoming direct method requests""" + await self._mqtt_client.enable_direct_method_request_receive() + try: + yield self._add_disconnect_interrupt_to_generator( + self._mqtt_client.incoming_direct_method_requests + ) + finally: + try: + if self._mqtt_client.connected: + await self._mqtt_client.disable_direct_method_request_receive() + except exc.MQTTError: + # i.e.
not connected + # This error would be expected if a disconnection has occurred + pass + + @contextlib.asynccontextmanager + @_requires_connection + async def desired_property_updates( + self, + ) -> AsyncGenerator[AsyncGenerator[custom_typing.TwinPatch, None], None]: + """Returns an async generator of incoming twin desired property patches""" + await self._mqtt_client.enable_twin_patch_receive() + try: + yield self._add_disconnect_interrupt_to_generator( + self._mqtt_client.incoming_twin_patches + ) + finally: + try: + if self._mqtt_client.connected: + await self._mqtt_client.disable_twin_patch_receive() + except exc.MQTTError: + # i.e. not connected + # This error would be expected if a disconnection has occurred + pass + + def _add_disconnect_interrupt_to_generator( + self, generator: AsyncGenerator[_T, None] + ) -> AsyncGenerator[_T, None]: + """Wrap a generator in another generator that will either return the next item yielded by + the original generator, or raise error in the event of disconnect + """ + + async def wrapping_generator(): + while True: + new_item_t = asyncio.create_task(generator.__anext__()) + done, _ = await asyncio.wait( + [new_item_t, self._wait_for_disconnect_task], + return_when=asyncio.FIRST_COMPLETED, + ) + if self._wait_for_disconnect_task in done: + new_item_t.cancel() + cause = self._wait_for_disconnect_task.result() + if cause is not None: + raise cause + else: + raise asyncio.CancelledError("Cancelled by disconnect") + else: + yield new_item_t.result() + + return wrapping_generator() + + def _add_disconnect_interrupt_to_coroutine(self, coro: Awaitable[_T]) -> Awaitable[_T]: + """Wrap a coroutine in another coroutine that will either return the result of the original + coroutine, or raise error in the event of disconnect + """ + + async def wrapping_coroutine(): + original_task = asyncio.create_task(coro) + done, _ = await asyncio.wait( + [original_task, self._wait_for_disconnect_task], + return_when=asyncio.FIRST_COMPLETED, + ) + if
self._wait_for_disconnect_task in done: + original_task.cancel() + cause = self._wait_for_disconnect_task.result() + if cause is not None: + raise cause + else: + raise asyncio.CancelledError("Cancelled by disconnect") + else: + return await original_task + + return wrapping_coroutine() + + @property + def connected(self) -> bool: + return self._mqtt_client.connected + + @property + def device_id(self) -> str: + return self._mqtt_client._device_id + + @property + def module_id(self) -> Optional[str]: + return self._mqtt_client._module_id + + +def _validate_kwargs(exclude=[], **kwargs) -> None: + """Helper function to validate user provided kwargs. + Raises TypeError if an invalid option has been provided""" + valid_kwargs = [ + # "auto_reconnect", + "keep_alive", + "product_info", + "proxy_options", + "websockets", + ] + + for kwarg in kwargs: + if (kwarg not in valid_kwargs) or (kwarg in exclude): + # NOTE: TypeError is the conventional error that is returned when an invalid kwarg is + # supplied. It feels like it should be a ValueError, but it's not. 
+ raise TypeError("Unsupported keyword argument: '{}'".format(kwarg)) + + +def _format_sas_uri(hostname: str, device_id: str, module_id: Optional[str]) -> str: + """Format the SAS URI for using IoT Hub""" + if module_id: + return "{hostname}/devices/{device_id}/modules/{module_id}".format( + hostname=hostname, device_id=device_id, module_id=module_id + ) + else: + return "{hostname}/devices/{device_id}".format(hostname=hostname, device_id=device_id) + + +def _default_ssl_context() -> ssl.SSLContext: + """Return a default SSLContext""" + ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.check_hostname = True + ssl_context.load_default_certs() + return ssl_context diff --git a/azure-iot-device/azure/iot/device/models.py b/azure-iot-device/azure/iot/device/models.py new file mode 100644 index 000000000..6e609ae39 --- /dev/null +++ b/azure-iot-device/azure/iot/device/models.py @@ -0,0 +1,206 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from typing import Optional, Dict, Union +from .custom_typing import JSONSerializable +from . import constant + +# TODO: Should Message property dictionaries be TypedDicts? + + +class Message: + """Represents a message to or from IoTHub + + :ivar payload: The data that constitutes the payload + :ivar content_encoding: Content encoding of the message data. Can be 'utf-8', 'utf-16' or 'utf-32' + :ivar content_type: Content type property used to route messages with the message-body. Can be 'application/json' + :ivar message_id: A user-settable identifier for the message used for request-reply patterns.
Format: A case-sensitive string (up to 128 characters long) of ASCII 7-bit alphanumeric characters + {'-', ':', '.', '+', '%', '_', '#', '*', '?', '!', '(', ')', ',', '=', '@', ';', '$', '''} + :ivar custom_properties: Dictionary of custom message properties. The keys and values of these properties will always be string. + :ivar output_name: Name of the output that the message is being sent to. + :ivar input_name: Name of the input that the message was received on. + :ivar ack: Indicates the type of feedback generation used by IoTHub + :ivar expiry_time_utc: Date and time of message expiration in UTC format + :ivar user_id: An ID to specify the origin of messages + :ivar correlation_id: A property in a response message that typically contains the message_id of the request, in request-reply patterns + """ + + def __init__( + self, + payload: Union[str, JSONSerializable], + content_encoding: str = "utf-8", + content_type: str = "text/plain", + output_name: Optional[str] = None, + ) -> None: + """ + Initializer for Message + + :param payload: The JSON serializable data that constitutes the payload. + :param str content_encoding: Content encoding of the message payload. + Acceptable values are 'utf-8', 'utf-16' and 'utf-32' + :param str content_type: Content type of the message payload. + Acceptable values are 'text/plain' and 'application/json' + :param str output_name: Name of the output that the message is being sent to. + """ + # Sanitize + if content_encoding not in ["utf-8", "utf-16", "utf-32"]: + raise ValueError( + "Invalid content encoding. Supported codecs are 'utf-8', 'utf-16' and 'utf-32'" + ) + if content_type not in ["text/plain", "application/json"]: + raise ValueError( + "Invalid content type. 
Supported types are 'text/plain' and 'application/json'" + ) + + # All Messages + self.payload = payload + self.content_encoding = content_encoding + self.content_type = content_type + self.message_id: Optional[str] = None + self.custom_properties: Dict[str, str] = {} + + # Outgoing Messages (D2C/Output) + self.output_name = output_name + self._iothub_interface_id: Optional[str] = None + + # Incoming Messages (C2D/Input) + self.input_name: Optional[str] = None + self.ack: Optional[str] = None + self.expiry_time_utc: Optional[str] = None + self.user_id: Optional[str] = None + self.correlation_id: Optional[str] = None + + def __str__(self) -> str: + return str(self.payload) + + @property + def iothub_interface_id(self): + return self._iothub_interface_id + + def set_as_security_message(self) -> None: + """ + Set the message as a security message. + """ + self._iothub_interface_id = constant.SECURITY_MESSAGE_INTERFACE_ID + + def get_system_properties_dict(self) -> Dict[str, str]: + """Return a dictionary of system properties""" + d = {} + # All messages + if self.message_id: + d["$.mid"] = self.message_id + if self.content_encoding: + d["$.ce"] = self.content_encoding + if self.content_type: + d["$.ct"] = self.content_type + # Outgoing Messages (D2C/Output) + if self.output_name: + d["$.on"] = self.output_name + if self._iothub_interface_id: + d["$.ifid"] = self._iothub_interface_id + # Incoming Messages (C2D/Input) + if self.input_name: + d["$.to"] = self.input_name + if self.ack: + d["iothub-ack"] = self.ack + if self.expiry_time_utc: + d["$.exp"] = self.expiry_time_utc + if self.user_id: + d["$.uid"] = self.user_id + if self.correlation_id: + d["$.cid"] = self.correlation_id + return d + + @classmethod + # TODO: should this just replace the __init__? 
+ def create_from_properties_dict( + cls, payload: JSONSerializable, properties: Dict[str, str] + ) -> "Message": + message = cls(payload) + + for key in properties: + # All messages + if key == "$.mid": + message.message_id = properties[key] + elif key == "$.ce": + message.content_encoding = properties[key] + elif key == "$.ct": + message.content_type = properties[key] + # Outgoing Messages (D2C/Output) + elif key == "$.on": + message.output_name = properties[key] + elif key == "$.ifid": + message._iothub_interface_id = properties[key] + # Incoming Messages (C2D/Input) + elif key == "$.to": + message.input_name = properties[key] + elif key == "iothub-ack": + message.ack = properties[key] + elif key == "$.exp": + message.expiry_time_utc = properties[key] + elif key == "$.uid": + message.user_id = properties[key] + elif key == "$.cid": + message.correlation_id = properties[key] + else: + message.custom_properties[key] = properties[key] + + return message + + +class DirectMethodRequest: + """Represents a request to invoke a direct method. + + :ivar str request_id: The request id. + :ivar str name: The name of the method to be invoked. + :ivar dict payload: The JSON payload being sent with the request. + :type payload: dict, str, int, float, bool, or None (JSON compatible values) + """ + + def __init__(self, request_id: str, name: str, payload: JSONSerializable) -> None: + """Initializer for a DirectMethodRequest. + + :param str request_id: The request id. + :param str name: The name of the method to be invoked + :param payload: The JSON payload being sent with the request. + :type payload: dict, str, int, float, bool, or None (JSON compatible values) + """ + self.request_id = request_id + self.name = name + self.payload = payload + + +class DirectMethodResponse: + """Represents a response to a direct method. + + :ivar str request_id: The request id of the DirectMethodRequest being responded to. 
+ :ivar int status: The status of the execution of the DirectMethodRequest. + :ivar payload: The JSON payload to be sent with the response. + :type payload: dict, str, int, float, bool, or None (JSON compatible values) + """ + + def __init__(self, request_id: str, status: int, payload: JSONSerializable = None) -> None: + """Initializer for DirectMethodResponse. + + :param str request_id: The request id of the DirectMethodRequest being responded to. + :param int status: The status of the execution of the DirectMethodRequest. + :param payload: The JSON payload to be sent with the response. (OPTIONAL) + :type payload: dict, str, int, float, bool, or None (JSON compatible values) + """ + self.request_id = request_id + self.status = status + self.payload = payload + + @classmethod + def create_from_method_request( + cls, method_request: DirectMethodRequest, status: int, payload: JSONSerializable = None + ): + """Factory method for creating a DirectMethodResponse from a DirectMethodRequest. + + :param method_request: The DirectMethodRequest object to respond to. + :type method_request: DirectMethodRequest. + :param int status: The status of the execution of the DirectMethodRequest. + :type payload: dict, str, int, float, bool, or None (JSON compatible values) + """ + return cls(request_id=method_request.request_id, status=status, payload=payload) diff --git a/azure-iot-device/azure/iot/device/mqtt_client.py b/azure-iot-device/azure/iot/device/mqtt_client.py new file mode 100644 index 000000000..3418e03d1 --- /dev/null +++ b/azure-iot-device/azure/iot/device/mqtt_client.py @@ -0,0 +1,878 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import asyncio +import functools +import logging +import paho.mqtt.client as mqtt # type: ignore +from paho.mqtt.client import MQTTMessage # noqa: F401 (Importing directly to re-export) +import ssl +from typing import Any, Dict, AsyncGenerator, Optional, Union +from .config import ProxyOptions + + +logger = logging.getLogger(__name__) + + +# NOTE: Paho can return a lot of rc values. However, most of them shouldn't happen. +# Here are the ones that we can expect for each method. +expected_connect_rc = [mqtt.MQTT_ERR_SUCCESS] +expected_disconnect_rc = [mqtt.MQTT_ERR_SUCCESS, mqtt.MQTT_ERR_NO_CONN] +expected_subscribe_rc = [mqtt.MQTT_ERR_SUCCESS, mqtt.MQTT_ERR_NO_CONN] +expected_unsubscribe_rc = [mqtt.MQTT_ERR_SUCCESS, mqtt.MQTT_ERR_NO_CONN] +expected_publish_rc = [mqtt.MQTT_ERR_SUCCESS, mqtt.MQTT_ERR_NO_CONN, mqtt.MQTT_ERR_QUEUE_SIZE] + +# Additionally, some are returned only via handler +expected_on_disconnect_rc = [ + mqtt.MQTT_ERR_SUCCESS, + mqtt.MQTT_ERR_CONN_REFUSED, + mqtt.MQTT_ERR_CONN_LOST, + mqtt.MQTT_ERR_KEEPALIVE, +] +expected_on_connect_rc = [ + mqtt.CONNACK_ACCEPTED, + mqtt.CONNACK_REFUSED_PROTOCOL_VERSION, + mqtt.CONNACK_REFUSED_IDENTIFIER_REJECTED, + mqtt.CONNACK_REFUSED_SERVER_UNAVAILABLE, + mqtt.CONNACK_REFUSED_BAD_USERNAME_PASSWORD, + mqtt.CONNACK_REFUSED_NOT_AUTHORIZED, +] + + +# NOTE: Paho has two kinds of errors and associated error codes: MQTT and MQTT CONNACK: +# +# - MQTT CONNACK errors only occur when attempting to establish a connection over MQTT, and thus +# map onto the MQTTConnectionFailedError exception defined below. +# +# - MQTT errors occur everywhere else, and in this implementation map to two different types of +# exceptions depending on context: +# +# - MQTT errors that result from a failed client operation (publish, subscribe, unsubscribe) +# map onto the MQTTError exception defined below. 
+# +# - MQTT errors that result from a lost connection that had previously been established map +# onto the MQTTConnectionDroppedError exception defined below. +# +# The reason for this artificially created split within basic MQTT errors is to provide +# flexibility to the user of this library in order to write try/except code that can easily +# differentiate between the cause of the error, and handle it accordingly. While it is true that +# MQTTConnectionDroppedErrors are exclusively retrieved via the `.previous_disconnection_cause()` +# method, an application that is sufficiently complex may want to re-raise such exceptions, +# and making a clear semantic split in error definitions can greatly simplify higher level +# try/except logic. + + +class MQTTError(Exception): + """Represents a failure with a Paho-given error rc code""" + + def __init__(self, rc): + self.rc = rc + super().__init__(mqtt.error_string(rc)) + + +class MQTTConnectionDroppedError(Exception): + """Represents a failure indicating a lost connection with a Paho-given error rc code""" + + def __init__(self, rc): + self.rc = rc + super().__init__(mqtt.error_string(rc)) + + +class MQTTConnectionFailedError(Exception): + """Represents a failure to connect. 
+ Can have a Paho-given connack rc code, or a message""" + + def __init__(self, rc=None, message=None, fatal=False): + if rc and message: + raise ValueError("rc and message are mutually exclusive") + self.rc = rc + self.fatal = fatal + if rc: + message = mqtt.connack_string(rc) + super().__init__(message) + + +class MQTTClient: + """ + Provides an async MQTT message broker interface + + This client currently only supports operations at a QoS (Quality of Service) of 1 + """ + + def __init__( + self, + client_id: str, + hostname: str, + port: int, + transport: str = "tcp", + keep_alive: int = 60, + auto_reconnect: bool = False, + reconnect_interval: int = 10, + ssl_context: Optional[ssl.SSLContext] = None, + websockets_path: Optional[str] = None, + proxy_options: Optional[ProxyOptions] = None, + ) -> None: + """ + Constructor to instantiate client. + + :param str client_id: The id of the client connecting to the broker. + :param str hostname: Hostname or IP address of the remote broker. + :param int port: Network port to connect to + :param str transport: "tcp" for TCP or "websockets" for WebSockets. + :param int keep_alive: Number of seconds before connection timeout. + :param bool auto_reconnect: Indicates whether or not client should reconnect when a + connection is unexpectedly dropped. + :param int reconnect_interval: Number of seconds between reconnect attempts + :param ssl_context: The SSL Context to use with MQTT. If not provided will use default. + :type ssl_context: :class:`ssl.SSLContext` + :param str websockets_path: Path for websocket connection. + Starts with '/' and should be the endpoint of the mqtt connection on the remote server. + :param proxy_options: Options for sending traffic through proxy servers. 
+ :type proxy_options: :class:`azure.iot.device.common.ProxyOptions` + """ + # Configuration + self._hostname = hostname + self._port = port + self._keep_alive = keep_alive + self._auto_reconnect = auto_reconnect + self._reconnect_interval = reconnect_interval + + # Client + self._mqtt_client = self._create_mqtt_client( + client_id, transport, ssl_context, proxy_options, websockets_path + ) + + # Event Loop + self._event_loop = asyncio.get_running_loop() + + # State + # NOTE: These values do not need to be protected by locks since the code paths that + # modify them cannot be invoked in parallel. + self._connected = False + self._desire_connection = False + self._disconnection_cause: Optional[MQTTConnectionDroppedError] = None + + # Synchronization + self.connected_cond = asyncio.Condition() + self.disconnected_cond = asyncio.Condition() + self._connection_lock = asyncio.Lock() + self._mid_tracker_lock = asyncio.Lock() + + # Tasks/Futures + self._network_loop: Optional[asyncio.Future] = None + self._reconnect_daemon: Optional[asyncio.Task] = None + # NOTE: pending connect is protected by the connection lock + # Other pending ops are protected by the _mid_tracker_lock + self._pending_connect: Optional[asyncio.Future] = None + self._pending_subs: Dict[int, asyncio.Future] = {} + self._pending_unsubs: Dict[int, asyncio.Future] = {} + self._pending_pubs: Dict[int, asyncio.Future] = {} + + # Incoming Data + self._incoming_messages: asyncio.Queue[mqtt.MQTTMessage] = asyncio.Queue() + self._incoming_filtered_messages: Dict[str, asyncio.Queue[mqtt.MQTTMessage]] = {} + + def _create_mqtt_client( + self, + client_id: str, + transport: str, + ssl_context: Optional[ssl.SSLContext], + proxy_options: Optional[ProxyOptions], + websockets_path: Optional[str], + ) -> mqtt.Client: + """ + Create the MQTT client object and assign all necessary event handler callbacks. 
+ """ + logger.debug("Creating Paho client") + + # Instantiate the client + mqtt_client = mqtt.Client( + client_id=client_id, + clean_session=False, + protocol=mqtt.MQTTv311, + transport=transport, + reconnect_on_failure=False, # We handle reconnect logic ourselves + ) + if transport == "websockets" and websockets_path: + logger.debug("Configuring Paho client for connecting using MQTT over websockets") + mqtt_client.ws_set_options(path=websockets_path) + else: + logger.debug("Configuring Paho client for connecting using MQTT over TCP") + + if proxy_options: + logger.debug("Configuring custom proxy options on Paho client") + mqtt_client.proxy_set( + proxy_type=proxy_options.proxy_type_socks, + proxy_addr=proxy_options.proxy_address, + proxy_port=proxy_options.proxy_port, + proxy_username=proxy_options.proxy_username, + proxy_password=proxy_options.proxy_password, + ) + + mqtt_client.enable_logger(logging.getLogger("paho")) + + # Configure TLS/SSL. If the value passed is None, will use default. 
+ mqtt_client.tls_set_context(context=ssl_context) + + def on_connect(client: mqtt.Client, userdata: Any, flags: Dict[str, int], rc: int) -> None: + message = mqtt.connack_string(rc) + logger.debug("Connect Response: rc {} - {}".format(rc, message)) + if rc not in expected_on_connect_rc: + logger.warning("Connect Response rc {} was unexpected".format(rc)) + + # Change state, report result, and notify connection established + async def set_result() -> None: + if rc == mqtt.CONNACK_ACCEPTED: + logger.debug("Client State: CONNECTED") + self._connected = True + self._desire_connection = True + self._disconnection_cause = None + async with self.connected_cond: + self.connected_cond.notify_all() + if self._pending_connect: + self._pending_connect.set_result(rc) + else: + logger.warning( + "Connect response received without outstanding attempt (likely was cancelled)" + ) + + f = asyncio.run_coroutine_threadsafe(set_result(), self._event_loop) + # Need to wait for this one to finish since we don't want to let another + # Paho handler invoke until we know the connection state has been set. + f.result() + + def on_disconnect(client: mqtt.Client, userdata: Any, rc: int) -> None: + rc_msg = mqtt.error_string(rc) + + # NOTE: It's not generally safe to use .is_connected() to determine what to do, since + # the value could change at any time. However, it IS safe to do so here. + # This handler, as well as .on_connect() above, are the only two functions that can + # change the value. They are both invoked on Paho's network loop, which is + # single-threaded. This means there cannot be overlapping invocations that would + # change the value, and thus that the value will not change during execution of this + # block. + if not self.is_connected(): + if rc == mqtt.MQTT_ERR_CONN_REFUSED: + # When Paho receives a failure response to a connect, the disconnect + # handler is also called. 
+                    # But we don't wish to issue spurious notifications or other behaviors
+                    logger.debug(
+                        "Connect Failure Disconnect Response: rc {} - {}".format(rc, rc_msg)
+                    )
+                else:
+                    # Double disconnect. Suppress.
+                    # Sometimes Paho disconnects twice. Why? Who knows.
+                    # But we don't wish to issue spurious notifications or other behaviors
+                    logger.debug("Double Disconnect Response: rc {} - {}".format(rc, rc_msg))
+            else:
+                if rc == mqtt.MQTT_ERR_SUCCESS:
+                    logger.debug("Disconnect Response: rc {} - {}".format(rc, rc_msg))
+                else:
+                    logger.debug("Unexpected Disconnect: rc {} - {}".format(rc, rc_msg))
+
+            # Change state and notify tasks waiting on disconnect
+            async def set_disconnected() -> None:
+                logger.debug("Client State: DISCONNECTED")
+                self._connected = False
+                if rc != mqtt.MQTT_ERR_SUCCESS:
+                    self._disconnection_cause = MQTTConnectionDroppedError(rc=rc)
+                async with self.disconnected_cond:
+                    self.disconnected_cond.notify_all()
+
+            f = asyncio.run_coroutine_threadsafe(set_disconnected(), self._event_loop)
+            # Need to wait for this one to finish since we don't want to let another
+            # Paho handler invoke until we know the connection state has been set.
+            f.result()
+
+            # Cancel pending subscribes and unsubscribes only.
+            # Publishes can survive a disconnect.
+            async def cancel_pending() -> None:
+                if len(self._pending_subs) != 0 or len(self._pending_unsubs) != 0:
+                    async with self._mid_tracker_lock:
+                        logger.debug("Cancelling pending subscribes")
+                        mids = self._pending_subs.keys()
+                        for mid in mids:
+                            self._pending_subs[mid].cancel()
+                        self._pending_subs.clear()
+                        logger.debug("Cancelling pending unsubscribes")
+                        mids = self._pending_unsubs.keys()
+                        for mid in mids:
+                            self._pending_unsubs[mid].cancel()
+                        self._pending_unsubs.clear()
+
+            # NOTE: This coroutine might not be able to finish right away due to the
+            # mid_tracker_lock. Don't wait on its completion or it may deadlock.
+ asyncio.run_coroutine_threadsafe(cancel_pending(), self._event_loop) + + def on_subscribe(client: mqtt.Client, userdata: Any, mid: int, granted_qos: int) -> None: + logger.debug("SUBACK received for mid {}".format(mid)) + + async def complete_sub() -> None: + async with self._mid_tracker_lock: + try: + f = self._pending_subs[mid] + f.set_result(True) + except KeyError: + logger.warning("Unexpected SUBACK received for mid {}".format(mid)) + + # NOTE: The complete_sub() coroutine cannot finish right away due to the + # mid_tracker_lock being held by the invocation of .subscribe(), waiting for a result. + # Do not wait on the completion of complete_sub() or this callback will deadlock the + # Paho network loop. Just schedule the eventual completion, and keep it moving. + asyncio.run_coroutine_threadsafe(complete_sub(), self._event_loop) + + def on_unsubscribe(client: mqtt.Client, userdata: Any, mid: int) -> None: + logger.debug("UNSUBACK received for mid {}".format(mid)) + + async def complete_unsub() -> None: + async with self._mid_tracker_lock: + try: + f = self._pending_unsubs[mid] + f.set_result(True) + except KeyError: + logger.warning("Unexpected UNSUBACK received for mid {}".format(mid)) + + # NOTE: The complete_unsub() coroutine cannot finish right away due to the + # mid_tracker_lock being held by the invocation of .unsubscribe(), waiting for a result. + # Do not wait on the completion of complete_unsub() or this callback will deadlock the + # Paho network loop. Just schedule the eventual completion, and keep it moving. 
+ asyncio.run_coroutine_threadsafe(complete_unsub(), self._event_loop) + + def on_publish(client: mqtt.Client, userdata: Any, mid: int) -> None: + logger.debug("PUBACK received for mid {}".format(mid)) + + async def complete_pub() -> None: + async with self._mid_tracker_lock: + try: + f = self._pending_pubs[mid] + f.set_result(True) + except KeyError: + logger.warning("Unexpected PUBACK received for mid {}".format(mid)) + + # NOTE: The complete_pub() coroutine cannot finish right away due to the + # mid_tracker_lock being held by the invocation of .publish(), waiting for a result. + # Do not wait on the completion of complete_pub() or this callback will deadlock the + # Paho network loop. Just schedule the eventual completion, and keep it moving. + asyncio.run_coroutine_threadsafe(complete_pub(), self._event_loop) + + def on_message(client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage) -> None: + logger.debug("Incoming MQTT Message received on {}".format(message.topic)) + + async def add_to_queue() -> None: + await self._incoming_messages.put(message) + + asyncio.run_coroutine_threadsafe(add_to_queue(), self._event_loop) + + mqtt_client.on_connect = on_connect + mqtt_client.on_disconnect = on_disconnect + mqtt_client.on_subscribe = on_subscribe + mqtt_client.on_unsubscribe = on_unsubscribe + mqtt_client.on_publish = on_publish + mqtt_client.on_message = on_message + + return mqtt_client + + async def _reconnect_loop(self) -> None: + """Reconnect logic""" + logger.debug("Reconnect Daemon starting...") + try: + while True: + async with self.disconnected_cond: + await self.disconnected_cond.wait_for( + lambda: not self.is_connected() and self._desire_connection + ) + try: + logger.debug("Reconnect Daemon attempting to reconnect...") + await self.connect() + logger.debug("Reconnect Daemon reconnect attempt succeeded") + except MQTTConnectionFailedError as e: + if not e.fatal: + interval = self._reconnect_interval + logger.debug( + "Reconnect Daemon reconnect 
attempt failed. Trying again in {} seconds".format( + interval + ) + ) + await asyncio.sleep(interval) + else: + logger.error("Reconnect failure was fatal - cannot reconnect") + logger.error(str(e)) + break + except asyncio.CancelledError: + logger.debug("Reconnect Daemon was cancelled") + raise + + def _network_loop_running(self) -> bool: + """Internal helper method to assess network loop""" + if self._network_loop and not self._network_loop.done(): + return True + else: + return False + + def is_connected(self) -> bool: + """ + Returns a boolean indicating whether the MQTT client is currently connected. + + Note that this value is only accurate as of the time it returns. It could change at + any point. + """ + return self._connected + + def previous_disconnection_cause(self) -> Optional[MQTTConnectionDroppedError]: + """ + Returns an MQTTConnectionDroppedError from the previous disconnection if it was unexpected as a result of + a connection drop. + Returns None if the previous disconnection attempt was intentional + """ + return self._disconnection_cause + + def set_credentials(self, username: str, password: Optional[str] = None) -> None: + """ + Set a username and optionally a password for broker authentication. + + Must be called before .connect() to have any effect. + + :param str username: The username for broker authentication + :param str password: The password for broker authentication (Optional) + """ + self._mqtt_client.username_pw_set(username=username, password=password) + + def add_incoming_message_filter(self, topic: str) -> None: + """ + Filter incoming messages on a specific topic. 
+ + :param str topic: The topic you wish to filter on + + :raises: ValueError if a filter is already applied for the topic + """ + if topic in self._incoming_filtered_messages: + raise ValueError("Filter already applied for this topic") + + # Add a Queue for this filter + self._incoming_filtered_messages[topic] = asyncio.Queue() + + def callback(client, userdata, message): + logger.debug("Incoming MQTT Message received on filter {}".format(message.topic)) + + async def add_to_queue(): + await self._incoming_filtered_messages[topic].put(message) + + asyncio.run_coroutine_threadsafe(add_to_queue(), self._event_loop) + + # Add the callback as a filter + self._mqtt_client.message_callback_add(topic, callback) + + def remove_incoming_message_filter(self, topic: str) -> None: + """ + Stop filtering incoming messages on a specific topic + + :param str topic: The topic you wish to stop filtering on + + :raises: ValueError if a filter is not already applied for the topic + """ + if topic not in self._incoming_filtered_messages: + raise ValueError("Filter not yet applied to this topic") + + # Remove the callback + self._mqtt_client.message_callback_remove(topic) + + # Delete the filter queue + del self._incoming_filtered_messages[topic] + + def get_incoming_message_generator( + self, filter_topic: Optional[str] = None + ) -> AsyncGenerator[mqtt.MQTTMessage, None]: + """ + Return a generator that yields incoming messages + + :param str filter_topic: The topic you wish to receive a generator for. 
+ If not provided, will return a generator for non-filtered messages + + :raises: ValueError if a filter is not already applied for the given topic + + :returns: A generator that yields incoming messages + """ + if filter_topic is not None and filter_topic not in self._incoming_filtered_messages: + raise ValueError("No filter applied for given topic") + elif filter_topic is not None: + incoming_messages = self._incoming_filtered_messages[filter_topic] + else: + incoming_messages = self._incoming_messages + + async def message_generator() -> AsyncGenerator[mqtt.MQTTMessage, None]: + while True: + yield await incoming_messages.get() + + return message_generator() + + async def connect(self) -> None: + """ + Connect to the MQTT broker using details set at instantiation. + + :raises: MQTTConnectionFailedError if there is a failure connecting + """ + # Wait for permission to alter the connection + async with self._connection_lock: + # NOTE: It's not generally safe to use .is_connected() to determine what to do, since + # the value could change at any time. However, it IS safe to do so here. The only way + # to become connected is to invoke a Paho .connect() and wait for a success. Due to the + # fact that this is the only method that can invoke Paho's .connect(), it does not + # return until a response is received, and it is protected by the ConnectionLock, + # we can be sure that there can't be overlapping invocations of Paho .connect(). + # Thus, we know that the state will not be changing on us within this block. + if not self.is_connected(): + + # Start the reconnect daemon (if enabled and not already running) + # + # NOTE: We need to track if the daemon was started by this attempt to know if + # we should cancel it in the event of this attempt being cancelled. Cancelling + # a connect attempt should not cancel a pre-existing reconnect daemon. + # + # Consider the case where a connection is established with a daemon, and the + # connection is later lost. 
In between automatic reconnect attempts, the .connect() + # method is invoked manually - if that manual connect attempt is cancelled, we + # should not be cancelling the pre-existing reconnect daemon that is trying to + # re-establish the original connection. + if self._auto_reconnect and not self._reconnect_daemon: + self._reconnect_daemon = asyncio.create_task(self._reconnect_loop()) + reconnect_started_on_this_attempt = True + else: + reconnect_started_on_this_attempt = False + + try: + await self._do_connect() + except asyncio.CancelledError: + logger.debug("Connect attempt was cancelled") + logger.warning( + "The cancelled connect attempt may still complete as it is in-flight" + ) + if self._reconnect_daemon and reconnect_started_on_this_attempt: + logger.debug( + "Reconnect daemon was started with this connect attempt. Cancelling it." + ) + self._reconnect_daemon.cancel() + self._reconnect_daemon = None + + # NOTE: Because a connection could still complete after cancellation due to + # it being in flight, this means that it's possible a connection could be + # established without a running reconnect daemon, even if auto_reconnect + # is enabled. This could be remedied fairly easily if so desired, but I've + # chosen to leave it out for simplicity. + else: + logger.debug( + "Reconnect daemon was started on a previous connect. Leaving it alone." 
+ ) + raise + finally: + # Pending operation is completed regardless of outcome + del self._pending_connect + self._pending_connect = None + + else: + logger.debug("Already connected!") + + async def _do_connect(self) -> None: + """Connect, start network loop, and wait for response""" + + # NOTE: we know this is safe because of the connection lock in the outer method + self._pending_connect = self._event_loop.create_future() + + # Paho Connect + logger.debug( + "Attempting connect to host {} using port {}...".format(self._hostname, self._port) + ) + try: + rc = await self._event_loop.run_in_executor( + None, + functools.partial( + self._mqtt_client.connect, + host=self._hostname, + port=self._port, + keepalive=self._keep_alive, + ), + ) + rc_msg = mqtt.error_string(rc) + logger.debug("Connect returned rc {} - {}".format(rc, rc_msg)) + # TODO: more specialization of errors to indicate which are/aren't retryable + except asyncio.CancelledError: + # Handled in outer method + raise + except Exception as e: + raise MQTTConnectionFailedError(message="Failure in Paho .connect()") from e + + if rc != mqtt.MQTT_ERR_SUCCESS: + # NOTE: This block should probably never execute. Paho's .connect() is + # supposed to only return success or raise an exception. + logger.warning("Unexpected rc {} from Paho .connect()".format(rc)) + # MQTTConnectionFailedError expects a connack rc, but this is a regular rc. + # So chain a regular mqtt exception into a connection mqtt exception. + try: + raise MQTTError(rc=rc) + except MQTTError as e: + raise MQTTConnectionFailedError(message="Unexpected Paho .connect() rc") from e + + # Start Paho network loop, and store the task. This task will complete upon disconnect + # whether due to invocation of .disconnect(), an unexpected network drop, or a connection + # failure (which Paho considers to be a disconnect) + # + # NOTE: If the connect attempt is cancelled, the network loop cannot be stopped. 
+ # This is because when using .loop_forever(), the loop lifecycle is managed by Paho. + # It cannot be manually ended; one of the above termination conditions must be met. + # + # However, in the case of a cancelled connect, it's likely that none of those conditions + # will be met, and thus the network loop will persist. + # This is fine, since it'll eventually get cleaned up, as at the very least, a + # .disconnect() invocation is required for graceful exit, if not sooner. + # + # But, this does introduce a case where the network loop may already be running + # during a connect attempt due to a previously cancelled attempt, so make sure it isn't + # before trying to start it again. + # + # NOTE: This MUST be called after connecting - loop_forever requires a socket to have been + # already established. This is not true of other network loop APIs, but it is true of this + # one. + if not self._network_loop_running(): + logger.debug("Starting Paho network loop") + self._network_loop = self._event_loop.run_in_executor( + None, self._mqtt_client.loop_forever + ) + else: + logger.debug( + "Paho network loop was already running. Likely due to a previous cancellation." + ) + + # The result of the CONNACK is received via the pending connect Future + logger.debug("Waiting for connect response...") + rc = await self._pending_connect + if rc != mqtt.CONNACK_ACCEPTED: + # If the connect failed, the network loop will stop. + # Might take a moment though, so wait on the network loop completion before clearing + if self._network_loop is not None: + # This block should always execute. This condition is just to help the type checker. + logger.debug("Waiting for network loop to exit and clearing task") + await self._network_loop + self._network_loop = None + raise MQTTConnectionFailedError(rc=rc) + + async def disconnect(self) -> None: + """ + Disconnect from the MQTT broker. + + Ensure this is called for graceful exit. 
+ """ + # Wait for permission to alter the connection + async with self._connection_lock: + + # We no longer wish to be connected + self._desire_connection = False + + # Cancel reconnection attempts + if self._reconnect_daemon: + logger.debug("Cancelling reconnect daemon") + self._reconnect_daemon.cancel() + self._reconnect_daemon = None + + # The network loop Future being present (running or not) indicates one of a few things: + # 1) We are connected + # 2) We were previously connected and the connection was lost + # 3) A connect attempt started the loop, and then was cancelled before connect finished + # In all of these cases, we need to invoke Paho's .disconnect() to clean up. + if self._network_loop: + + # Paho Disconnect + # NOTE: Paho disconnect shouldn't raise any exceptions + logger.debug("Attempting disconnect") + rc = await self._event_loop.run_in_executor(None, self._mqtt_client.disconnect) + rc_msg = mqtt.error_string(rc) + logger.debug("Disconnect returned rc {} - {}".format(rc, rc_msg)) + + if rc == mqtt.MQTT_ERR_SUCCESS: + # Wait for disconnection to complete + logger.debug("Waiting for disconnect to complete...") + async with self.disconnected_cond: + await self.disconnected_cond.wait_for(lambda: not self.is_connected()) + logger.debug("Waiting for network loop to exit and clearing task") + await self._network_loop + self._network_loop = None + # Wait slightly for tasks started by the on_disconnect handler to finish. + # This will prevent warnings. + # TODO: improve efficiency by being able to wait on something specific + await asyncio.sleep(0.02) + elif rc == mqtt.MQTT_ERR_NO_CONN: + # This happens when we disconnect while already disconnected. + # In this implementation, it should only happen if Paho's inner state + # indicates we would like to be connected, but we actually aren't. + # We still want to do this disconnect however, because doing so changes + # Paho's state to indicate we no longer wish to be connected. 
+ logger.debug("Early disconnect return (Already disconnected)") + logger.debug("Clearing network loop task") + self._network_loop = None + logger.debug("Clearing previous disconnection cause") + self._disconnection_cause = None + else: + # This block should never execute + logger.warning( + "Unexpected rc {} from Paho .disconnect(). Doing nothing.".format(rc) + ) + + else: + logger.debug("Already disconnected!") + + async def subscribe(self, topic: str) -> None: + """ + Subscribe to a topic from the MQTT broker. + + :param str topic: a single string specifying the subscription topic to subscribe to + + :raises: ValueError if topic is None or has zero string length. + :raises: MQTTError if there is an error subscribing + :raises: CancelledError if network failure occurs while in-flight + """ + try: + mid = None + logger.debug("Attempting subscribe to topic {}".format(topic)) + # Using this lock postpones any code that runs in the on_subscribe callback that will + # be invoked on response, as the callback also uses the lock. This ensures that the + # result cannot be received before we have a Future created for the eventual result. 
+            async with self._mid_tracker_lock:
+                (rc, mid) = await self._event_loop.run_in_executor(
+                    None, functools.partial(self._mqtt_client.subscribe, topic=topic, qos=1)
+                )
+                rc_msg = mqtt.error_string(rc)
+                logger.debug("Subscribe returned rc {} - {}".format(rc, rc_msg))
+                if rc != mqtt.MQTT_ERR_SUCCESS:
+                    if rc not in expected_subscribe_rc:
+                        logger.warning("Unexpected rc {} from Paho .subscribe()".format(rc))
+                    raise MQTTError(rc)
+
+                # Establish a pending subscribe
+                sub_done = self._event_loop.create_future()
+                self._pending_subs[mid] = sub_done
+
+            logger.debug("Waiting for SUBACK for mid {}".format(mid))
+            await sub_done
+        except asyncio.CancelledError:
+            if mid:
+                logger.debug("Subscribe for mid {} was cancelled".format(mid))
+            else:
+                logger.debug("Subscribe was cancelled before mid was assigned")
+            raise
+        finally:
+            # Delete any pending operation (if it exists)
+            async with self._mid_tracker_lock:
+                if mid and mid in self._pending_subs:
+                    del self._pending_subs[mid]
+
+    async def unsubscribe(self, topic: str) -> None:
+        """
+        Unsubscribe from a topic on the MQTT broker.
+
+        :param str topic: a single string which is the subscription topic to unsubscribe from.
+
+        :raises: ValueError if topic is None or has zero string length.
+        :raises: MQTTError if there is an error unsubscribing
+        :raises: CancelledError if network failure occurs while in-flight
+        """
+        try:
+            mid = None
+            logger.debug("Attempting unsubscribe from topic {}".format(topic))
+            # Using this lock postpones any code that runs in the on_unsubscribe callback that will
+            # be invoked on response, as the callback also uses the lock. This ensures that the
+            # result cannot be received before we have a Future created for the eventual result.
+            async with self._mid_tracker_lock:
+                (rc, mid) = await self._event_loop.run_in_executor(
+                    None, functools.partial(self._mqtt_client.unsubscribe, topic=topic)
+                )
+                rc_msg = mqtt.error_string(rc)
+                logger.debug("Unsubscribe returned rc {} - {}".format(rc, rc_msg))
+                if rc != mqtt.MQTT_ERR_SUCCESS:
+                    if rc not in expected_unsubscribe_rc:
+                        logger.warning("Unexpected rc {} from Paho .unsubscribe()".format(rc))
+                    raise MQTTError(rc)
+
+                # Establish a pending unsubscribe
+                unsub_done = self._event_loop.create_future()
+                self._pending_unsubs[mid] = unsub_done
+
+            logger.debug("Waiting for UNSUBACK for mid {}".format(mid))
+            await unsub_done
+        except asyncio.CancelledError:
+            if mid:
+                logger.debug("Unsubscribe for mid {} was cancelled".format(mid))
+            else:
+                logger.debug("Unsubscribe was cancelled before mid was assigned")
+            raise
+        finally:
+            # Delete any pending operation (if it exists)
+            async with self._mid_tracker_lock:
+                if mid and mid in self._pending_unsubs:
+                    del self._pending_unsubs[mid]
+
+    async def publish(self, topic: str, payload: Union[str, bytes, int, float, None]) -> None:
+        """
+        Send a message via the MQTT broker.
+
+        Messages are always published at QoS 1.
+
+        :param str topic: The topic that the message should be published on.
+        :param payload: The actual message to send.
+        :type payload: str, bytes, int, float or None
+ + :raises: ValueError if topic is None or has zero string length + :raises: ValueError if topic contains a wildcard ("+") + :raises: ValueError if the length of the payload is greater than 268435455 bytes + :raises: TypeError if payload is not a valid type + :raises: MQTTError if there is an error publishing + """ + try: + mid = None + logger.debug("Attempting publish to topic {}".format(topic)) + logger.debug("Publish payload: {}".format(str(payload))) + # Using this lock postpones any code that runs in the on_publish callback that will + # be invoked on response, as the callback also uses the lock. This ensures that the + # result cannot be received before we have a Future created for the eventual result. + async with self._mid_tracker_lock: + message_info = await self._event_loop.run_in_executor( + None, + functools.partial( + self._mqtt_client.publish, topic=topic, payload=payload, qos=1 + ), + ) + mid = message_info.mid + rc_msg = mqtt.error_string(message_info.rc) + logger.debug("Publish returned rc {} - {}".format(message_info.rc, rc_msg)) + if message_info.rc == mqtt.MQTT_ERR_NO_CONN: + logger.debug("MQTT Client not connected - will publish upon next connect") + elif message_info.rc != mqtt.MQTT_ERR_SUCCESS: + if message_info.rc not in expected_publish_rc: + logger.warning( + "Unexpected rc {} from Paho .publish()".format(message_info.rc) + ) + raise MQTTError(message_info.rc) + + # Establish a pending publish + pub_done = self._event_loop.create_future() + self._pending_pubs[mid] = pub_done + + logger.debug("Waiting for PUBACK") + # NOTE: Yes, message_info has a method called 'wait_for_publish' which would simplify + # things, however it has strange behavior in the case of disconnection - it raises a + # RuntimeError. However, the publish actually persists and still will be sent upon a + # connection, even though the message_info will NEVER be able to be used to track it + # (even after connection established). 
+ # So, alas, we do it the messy handler/Future way, same as with sub and unsub. + await pub_done + except asyncio.CancelledError: + if mid: + logger.debug("Publish for mid {} was cancelled".format(mid)) + logger.warning("The cancelled publish may still be delivered if it was in-flight") + else: + logger.debug("Publish was cancelled before mid was assigned") + raise + finally: + # Delete any pending operation (if it exists) + async with self._mid_tracker_lock: + if mid and mid in self._pending_pubs: + del self._pending_pubs[mid] diff --git a/azure-iot-device/azure/iot/device/mqtt_topic_iothub.py b/azure-iot-device/azure/iot/device/mqtt_topic_iothub.py new file mode 100644 index 000000000..edce31bbc --- /dev/null +++ b/azure-iot-device/azure/iot/device/mqtt_topic_iothub.py @@ -0,0 +1,286 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import logging +import urllib.parse +from typing import Optional, Union, Dict + +logger = logging.getLogger(__name__) + +# NOTE: Whenever using standard URL encoding via the urllib.parse.quote() API +# make sure to specify that there are NO safe values (e.g. safe=""). By default +# "/" is skipped in encoding, and that is not desirable. +# +# DO NOT use urllib.parse.quote_plus(), as it turns ' ' characters into '+', +# which is invalid for MQTT publishes. +# +# DO NOT use urllib.parse.unquote_plus(), as it turns '+' characters into ' ', +# which is also invalid. + + +# NOTE (Oct 2020): URL encoding policy is currently inconsistent in this module due to restrictions +# with the Hub, as Hub does not do URL decoding on most values. +# (see: https://github.com/Azure/azure-iot-sdk-python/wiki/URL-Encoding-(MQTT)). 
+# Currently, as much as possible is URL encoded while keeping in line with the policy outlined +# in the above linked wiki article. This is to say that Device ID and Module ID are never +# encoded, however other values are. By convention, it's probably fine to be encoding/decoding most +# values that are not Device ID or Module ID, since it won't make a difference in production as +# the narrow range of acceptable values for, say, status code, or request ID don't contain any +# characters that require URL encoding/decoding in the first place. Thus it doesn't break on Hub, +# but it's still done here as a client-side best practice - Hub may eventually be doing a new API +# that does correctly URL encode/decode all values, so it's not good to roll back more than +# is currently necessary to avoid errors. + + +def _get_topic_base(device_id: str, module_id: Optional[str] = None) -> str: + """ + return the string that is at the beginning of all topics for this + device/module + """ + + # NOTE: Neither Device ID nor Module ID should be URL encoded in a topic string. + # See the repo wiki article for details: + # https://github.com/Azure/azure-iot-sdk-python/wiki/URL-Encoding-(MQTT) + topic = "devices/" + str(device_id) + if module_id: + topic = topic + "/modules/" + str(module_id) + return topic + + +def get_c2d_topic_for_subscribe(device_id: str) -> str: + """ + :return: The topic for cloud to device messages.It is of the format + "devices//messages/devicebound/#" + """ + return _get_topic_base(device_id) + "/messages/devicebound/#" + + +def get_input_topic_for_subscribe(device_id: str, module_id: str) -> str: + """ + :return: The topic for input messages. It is of the format + "devices//modules//inputs/#" + """ + return _get_topic_base(device_id, module_id) + "/inputs/#" + + +def get_direct_method_request_topic_for_subscribe() -> str: + """ + :return: The topic for ALL incoming direct methods. 
It is of the format + "$iothub/methods/POST/#" + """ + return "$iothub/methods/POST/#" + + +def get_twin_response_topic_for_subscribe() -> str: + """ + :return: The topic for ALL incoming twin responses. It is of the format + "$iothub/twin/res/#" + """ + return "$iothub/twin/res/#" + + +def get_twin_patch_topic_for_subscribe() -> str: + """ + :return: The topic for ALL incoming twin patches. It is of the format + "$iothub/twin/PATCH/properties/desired/# + """ + return "$iothub/twin/PATCH/properties/desired/#" + + +def get_telemetry_topic_for_publish(device_id: str, module_id: Optional[str] = None) -> str: + """ + return the topic string used to publish telemetry + """ + return _get_topic_base(device_id, module_id) + "/messages/events/" + + +def get_direct_method_response_topic_for_publish(request_id: str, status: Union[str, int]) -> str: + """ + :return: The topic for publishing direct method responses. It is of the format + "$iothub/methods/res//?$rid=" + """ + return "$iothub/methods/res/{status}/?$rid={request_id}".format( + status=urllib.parse.quote(str(status), safe=""), + request_id=urllib.parse.quote(str(request_id), safe=""), + ) + + +def get_twin_request_topic_for_publish(request_id: str) -> str: + """ + :return: The topic for publishing a get twin request. It is of the format + "$iothub/twin/GET/?$rid=" + """ + return "$iothub/twin/GET/?$rid={request_id}".format( + request_id=urllib.parse.quote(str(request_id), safe="") + ) + + +def get_twin_patch_topic_for_publish(request_id: str) -> str: + """ + :return: The topic for publishing a twin patch. 
It is of the format + "$iothub/twin/PATCH/properties/reported?$rid=" + """ + return "$iothub/twin/PATCH/properties/reported/?$rid={request_id}".format( + request_id=urllib.parse.quote(str(request_id), safe="") + ) + + +def insert_message_properties_in_topic( + topic: str, + system_properties: Dict[str, str], + custom_properties: Dict[str, str], +) -> str: + """ + URI encode system and custom properties into a message topic. + + :param dict system_properties: A dictionary mapping system properties to their values + :param dict custom_properties: A dictionary mapping custom properties to their values. + :return: The modified topic containing the encoded properties + """ + if system_properties: + encoded_system_properties = urllib.parse.urlencode( + system_properties, quote_via=urllib.parse.quote + ) + topic += encoded_system_properties + if system_properties and custom_properties: + topic += "&" + if custom_properties: + encoded_custom_properties = urllib.parse.urlencode( + custom_properties, quote_via=urllib.parse.quote + ) + topic += encoded_custom_properties + return topic + + +def extract_properties_from_message_topic(topic: str) -> Dict[str, str]: + """ + Extract key=value pairs from an incoming message topic, returning them as a dictionary. + If a key has no matching value, the value will be set to empty string. + + :param str topic: The topic string + :returns: dictionary mapping keys to values. 
+ """ + parts = topic.split("/") + # Input Message Topic + if len(parts) > 4 and parts[4] == "inputs": + if len(parts) > 6: + properties_string = parts[6] + else: + properties_string = "" + # C2D Message Topic + elif len(parts) > 3 and parts[3] == "devicebound": + if len(parts) > 4: + properties_string = parts[4] + else: + properties_string = "" + else: + raise ValueError("topic has incorrect format") + + return _extract_properties(properties_string) + + +def extract_name_from_direct_method_request_topic(topic: str) -> str: + """ + Extract the direct method name from the direct method topic. + Topics for direct methods are of the following format: + "$iothub/methods/POST/{method name}/?$rid={request id}" + + :param str topic: The topic string + :return: method name from topic string + """ + parts = topic.split("/") + if topic.startswith("$iothub/methods/POST") and len(parts) >= 4: + return urllib.parse.unquote(parts[3]) + else: + raise ValueError("topic has incorrect format") + + +def extract_request_id_from_direct_method_request_topic(topic: str) -> str: + """ + Extract the Request ID (RID) from the direct method topic. + Topics for direct methods are of the following format: + "$iothub/methods/POST/{method name}/?$rid={request id}" + + :param str topic: the topic string + :raises: ValueError if topic has incorrect format + :returns: request id from topic string + """ + parts = topic.split("/") + if topic.startswith("$iothub/methods/POST") and len(parts) >= 4: + properties = _extract_properties(topic.split("?")[1]) + rid = properties.get("$rid") + if not rid: + raise ValueError("No request id in topic") + return rid + else: + raise ValueError("topic has incorrect format") + + +def extract_status_code_from_twin_response_topic(topic: str) -> str: + """ + Extract the status code from the twin response topic. 
+ Topics for twin response are in the following format: + "$iothub/twin/res/{status}/?$rid={rid}" + + :param str topic: The topic string + :raises: ValueError if the topic has incorrect format + :returns status code from topic string + """ + parts = topic.split("/") + if topic.startswith("$iothub/twin/res/") and len(parts) >= 4: + return urllib.parse.unquote(parts[3]) + else: + raise ValueError("topic has incorrect format") + + +def extract_request_id_from_twin_response_topic(topic: str) -> str: + """ + Extract the Request ID (RID) from the twin response topic. + Topics for twin response are in the following format: + "$iothub/twin/res/{status}/?$rid={rid}" + + :param str topic: The topic string + :raises: ValueError if topic has incorrect format + :returns: request id from topic string + """ + parts = topic.split("/") + if topic.startswith("$iothub/twin/res/") and len(parts) >= 4: + properties = _extract_properties(topic.split("?")[1]) + rid = properties.get("$rid") + if not rid: + raise ValueError("No request id in topic") + return rid + else: + raise ValueError("topic has incorrect format") + + +# NOTE: This is duplicated from mqtt_topic_provisioning. If changing, change there too. +# Consider putting this in a separate module at some point. 
+def _extract_properties(properties_str: str) -> Dict[str, str]: + """Return a dictionary of properties from a string in the format + {key1}={value1}&{key2}={value2}...&{keyn}={valuen} + + For extracting values corresponding to keys the following rules are followed:- + If there is a just a key with no "=", the value is an empty string + """ + d: Dict[str, str] = {} + if len(properties_str) == 0: + # There are no properties, return empty + return d + + kv_pairs = properties_str.split("&") + for entry in kv_pairs: + pair = entry.split("=") + key = urllib.parse.unquote(pair[0]) + if len(pair) > 1: + # Key/Value Pair + value = urllib.parse.unquote(pair[1]) + else: + # Key with no value -> value = None + value = "" + d[key] = value + + return d diff --git a/azure-iot-device/azure/iot/device/mqtt_topic_provisioning.py b/azure-iot-device/azure/iot/device/mqtt_topic_provisioning.py new file mode 100644 index 000000000..fb99bb1a3 --- /dev/null +++ b/azure-iot-device/azure/iot/device/mqtt_topic_provisioning.py @@ -0,0 +1,114 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import logging +import urllib.parse +from typing import Dict + +logger = logging.getLogger(__name__) + +# NOTE: Whenever using standard URL encoding via the urllib.parse.quote() API +# make sure to specify that there are NO safe values (e.g. safe=""). By default +# "/" is skipped in encoding, and that is not desirable. +# +# DO NOT use urllib.parse.quote_plus(), as it turns ' ' characters into '+', +# which is invalid for MQTT publishes. +# +# DO NOT use urllib.parse.unquote_plus(), as it turns '+' characters into ' ', +# which is also invalid. 
+ + +def get_response_topic_for_subscribe() -> str: + """ + :return: The topic string used to subscribe for receiving registration responses from DPS. + It is of the format "$dps/registrations/res/#" + """ + return "$dps/registrations/res/#" + + +def get_register_topic_for_publish(request_id: str) -> str: + """ + :return: The topic string used to send a registration. It is of the format + "$dps/registrations/PUT/iotdps-register/?$rid= + """ + return "$dps/registrations/PUT/iotdps-register/?$rid={request_id}".format( + request_id=urllib.parse.quote(str(request_id), safe="") + ) + + +def get_status_query_topic_for_publish(request_id: str, operation_id: str) -> str: + """ + :return: The topic string used to send an operation status query. It is of the format + "$dps/registrations/GET/iotdps-get-operationstatus/?$rid=&operationId= + """ + return "$dps/registrations/GET/iotdps-get-operationstatus/?$rid={request_id}&operationId={operation_id}".format( + request_id=urllib.parse.quote(str(request_id), safe=""), + operation_id=urllib.parse.quote(str(operation_id), safe=""), + ) + + +def extract_properties_from_response_topic(topic: str) -> Dict[str, str]: + """Extract key/value pairs from the response topic, returning them as a dictionary. + If a key has no matching value, the value will be set to empty string. + + Topics for responses from DPS are of the following format: + $dps/registrations/res//?$=&=...&= + + :param topic: The topic string + :return: a dictionary of property keys mapped to property values. 
+ """ + parts = topic.split("/") + if topic.startswith("$dps/registrations/res/") and len(parts) == 5: + properties_string = parts[4].split("?")[1] + return _extract_properties(properties_string) + else: + raise ValueError("topic has incorrect format") + + +def extract_status_code_from_response_topic(topic: str) -> str: + """ + Extract the status code from the response topic + + Topics for responses from DPS are of the following format: + $dps/registrations/res//?$=&=...&= + Extract the status code part from the topic. + :param topic: The topic string + :return: The status code from the DPS response topic, as a string + """ + parts = topic.split("/") + if topic.startswith("$dps/registrations/res/") and len(parts) >= 4: + return urllib.parse.unquote(parts[3]) + else: + raise ValueError("topic has incorrect format") + + +# NOTE: This is duplicated from mqtt_topic_iothub. If changing, change there too. +# Consider putting this in a separate module at some point. +def _extract_properties(properties_str: str) -> Dict[str, str]: + """Return a dictionary of properties from a string in the format + {key1}={value1}&{key2}={value2}...&{keyn}={valuen} + + For extracting values corresponding to keys the following rules are followed:- + If there is a just a key with no "=", the value is an empty string + """ + d: Dict[str, str] = {} + if len(properties_str) == 0: + # There are no properties, return empty + return d + + kv_pairs = properties_str.split("&") + for entry in kv_pairs: + pair = entry.split("=") + key = urllib.parse.unquote(pair[0]) + if len(pair) > 1: + # Key/Value Pair + value = urllib.parse.unquote(pair[1]) + else: + # Key with no value -> value = None + value = "" + d[key] = value + + return d diff --git a/azure-iot-device/azure/iot/device/patch.py b/azure-iot-device/azure/iot/device/patch.py deleted file mode 100644 index 5d4bec270..000000000 --- a/azure-iot-device/azure/iot/device/patch.py +++ /dev/null @@ -1,167 +0,0 @@ -# 
------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module provides patches used to dynamically modify items from the libraries""" - -import inspect -import logging - -logger = logging.getLogger(__name__) - -# This dict will be used as a scope for imports and defs in add_shims_for_inherited_methods -# in order to keep them out of the global scope of this module. -shim_scope = {} - - -def add_shims_for_inherited_methods(target_class): - """Dynamically add overriding, pass-through shim methods for all public inherited methods - on a child class, which simply call into the parent class implementation of the same method. - - These shim methods will include the same docstrings as the method from the parent class. - - This currently only works for Python 3.5+ - - Using DEBUG logging will allow you to see output of all dynamic operations that occur within - for debugging purposes. - - :param target_class: The child class to add shim methods to - """ - - # Depending on how the method was defined, it could be either a function or a method. - # Thus we need to find the union of the two sets. - # Here instance methods are considered functions because they are not yet bound to an instance - # of the class. Classmethods on the other hand, are already bound, and show up as methods. - # It also is worth noting that async functions/methods ARE picked up by this introspection. - class_functions = inspect.getmembers(target_class, predicate=inspect.isfunction) - class_methods = inspect.getmembers(target_class, predicate=inspect.ismethod) - all_methods = class_functions + class_methods - - # This list of attributes gives us a lot of information, but we only are using it to get - # the defining class of a given method. 
- class_attributes = inspect.classify_class_attrs(target_class) - - # We must alias class names to prevent naming collisions when this fn is called multiple times - # with classes that share a name. If we've already used this classname, add trailing underscore(s) - classname_alias = target_class.__name__ - while classname_alias in shim_scope: - classname_alias += "_" - - # Import the class we're adding methods to, so that functions defined in this scope can use super() - class_module = inspect.getmodule(target_class) - import_cmdstr = "from {module} import {target_class} as {alias}".format( - module=class_module.__name__, target_class=target_class.__name__, alias=classname_alias - ) - logger.debug("exec: " + import_cmdstr) - # exec(import_cmdstr, shim_scope) - - for method in all_methods: - method_name = method[0] - method_obj = method[1] - # We can index on 0 here because the list comprehension will always be exactly 1 element - method_attribute = [att for att in class_attributes if att.name == method_name][0] - # The object of the class where the method was originally defined. - originating_class_obj = method_attribute.defining_class - - # Create a shim method for all public methods inherited from a parent class - if method_name[0] != "_" and originating_class_obj != target_class: - - method_sig = inspect.signature(method_obj) - sig_params = method_sig.parameters - - # Bound methods (i.e. classmethods) remove the first parameter (i.e. 
cls) - # so we need to add it back - if inspect.ismethod(method_obj): - complete_params = [] - complete_params.append( - inspect.Parameter("cls", inspect.Parameter.POSITIONAL_OR_KEYWORD) - ) - complete_params += list(sig_params.values()) - method_sig = method_sig.replace(parameters=complete_params) - - # Since neither "self" nor "cls" are used in invocation, we need to create a new - # invocation signature without them - invoke_params_list = [] - for param in sig_params.values(): - if param.name != "self" and param.name != "cls": - # Set the parameter to empty (since we use this in an invocation, not a signature) - new_param = param.replace(default=inspect.Parameter.empty) - invoke_params_list.append(new_param) - invoke_params = method_sig.replace(parameters=invoke_params_list) - - # Choose syntactical variants - if inspect.ismethod(method_obj): - obj_or_type = "cls" # Use 'cls' to invoke super() for classmethods - else: - obj_or_type = "self" # Use 'self' to invoke super() for instance methods - if inspect.iscoroutine(method_obj) or inspect.iscoroutinefunction(method_obj): - def_syntax = "async def" # Define coroutine function/method - ret_syntax = "return await" - else: - def_syntax = "def" # Define function/method - ret_syntax = "return" - - # Dynamically define a new shim function, with the same name, that invokes the method of the parent class - fn_def_cmdstr = "{def_syntax} {method_name}{signature}: {ret_syntax} super({leaf_class}, {object_or_type}).{method_name}{invocation}".format( - def_syntax=def_syntax, - method_name=method_name, - signature=str(method_sig), - ret_syntax=ret_syntax, - leaf_class=classname_alias, - object_or_type=obj_or_type, - invocation=str(invoke_params), - ) - logger.debug("exec: " + fn_def_cmdstr) - # exec(fn_def_cmdstr, shim_scope) - - # Copy the docstring from the method to the shim function - set_doc_cmdstr = "{method_name}.__doc__ = {leaf_class}.{method_name}.__doc__".format( - method_name=method_name, leaf_class=classname_alias 
- ) - logger.debug("exec: " + set_doc_cmdstr) - # exec(set_doc_cmdstr, shim_scope) - - # Add shim function to leaf/child class as a classmethod if the method being shimmed is a classmethod - if inspect.ismethod(method_obj): - attach_shim_cmdstr = ( - "setattr({leaf_class}, '{method_name}', classmethod({method_name}))".format( - leaf_class=classname_alias, method_name=method_name - ) - ) - # Add shim function to leaf/child class as a method if the method being shimmed is an instance method - else: - attach_shim_cmdstr = "setattr({leaf_class}, '{method_name}', {method_name})".format( - leaf_class=classname_alias, method_name=method_name - ) - logger.debug("exec: " + attach_shim_cmdstr) - # exec(attach_shim_cmdstr, shim_scope) - - # NOTE: the __qualname__ attributes of these new shim methods are merely the method name, - # rather than ., due to the scoping of the definition. - # This shouldn't matter, but in case it does, I am documenting that fact here. - - # For properties, we have a different strategy. While with methods we dynamically created - # redefinitions for each inherited method that called the parent class' implementation of the - # method, here we simply set the inherited property attribute directly onto the child. - # This will carry over all docstrings implicitly. - # We do this because properties, while defined syntactically via methods, actually are not - # methods directly on a class, but form a "property" object, which itself contains the - # get and set logic. Thus our strategy for methods can't really work here. - class_properties = inspect.getmembers(target_class, predicate=inspect.isdatadescriptor) - for prop in class_properties: - property_name = prop[0] - # We can index on 0 here because the list comprehension will always be exactly 1 element - property_attribute = [att for att in class_attributes if att.name == property_name][0] - # The object of the class where the property was originally defined. 
- originating_class_obj = property_attribute.defining_class - - # Simply redefine the same property on the leaf class if it was defined on a parent - if property_name[0] != "_" and originating_class_obj != target_class: - attach_property_cmdstr = ( - "setattr({leaf_class}, '{property_name}', {leaf_class}.{property_name})".format( - leaf_class=classname_alias, property_name=property_name - ) - ) - logger.debug("exec: " + attach_property_cmdstr) - # exec(attach_property_cmdstr, shim_scope) diff --git a/azure-iot-device/azure/iot/device/patch_documentation.py b/azure-iot-device/azure/iot/device/patch_documentation.py deleted file mode 100644 index 3b0ab9f48..000000000 --- a/azure-iot-device/azure/iot/device/patch_documentation.py +++ /dev/null @@ -1,314 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module provides hard coded patches used to modify items from the libraries. 
-Currently we have to do like this so that we don't use exec anywhere""" - - -def execute_patch_for_sync(): - from azure.iot.device.iothub.sync_clients import IoTHubDeviceClient as IoTHubDeviceClient - - def connect(self): - return super(IoTHubDeviceClient, self).connect() - - connect.__doc__ = IoTHubDeviceClient.connect.__doc__ - setattr(IoTHubDeviceClient, "connect", connect) - - def disconnect(self): - return super(IoTHubDeviceClient, self).disconnect() - - disconnect.__doc__ = IoTHubDeviceClient.disconnect.__doc__ - setattr(IoTHubDeviceClient, "disconnect", disconnect) - - def get_twin(self): - return super(IoTHubDeviceClient, self).get_twin() - - get_twin.__doc__ = IoTHubDeviceClient.get_twin.__doc__ - setattr(IoTHubDeviceClient, "get_twin", get_twin) - - def patch_twin_reported_properties(self, reported_properties_patch): - return super(IoTHubDeviceClient, self).patch_twin_reported_properties( - reported_properties_patch - ) - - patch_twin_reported_properties.__doc__ = ( - IoTHubDeviceClient.patch_twin_reported_properties.__doc__ - ) - setattr(IoTHubDeviceClient, "patch_twin_reported_properties", patch_twin_reported_properties) - - def receive_method_request(self, method_name=None, block=True, timeout=None): - return super(IoTHubDeviceClient, self).receive_method_request(method_name, block, timeout) - - receive_method_request.__doc__ = IoTHubDeviceClient.receive_method_request.__doc__ - setattr(IoTHubDeviceClient, "receive_method_request", receive_method_request) - - def receive_twin_desired_properties_patch(self, block=True, timeout=None): - return super(IoTHubDeviceClient, self).receive_twin_desired_properties_patch(block, timeout) - - receive_twin_desired_properties_patch.__doc__ = ( - IoTHubDeviceClient.receive_twin_desired_properties_patch.__doc__ - ) - setattr( - IoTHubDeviceClient, - "receive_twin_desired_properties_patch", - receive_twin_desired_properties_patch, - ) - - def send_message(self, message): - return super(IoTHubDeviceClient, 
self).send_message(message) - - send_message.__doc__ = IoTHubDeviceClient.send_message.__doc__ - setattr(IoTHubDeviceClient, "send_message", send_message) - - def send_method_response(self, method_response): - return super(IoTHubDeviceClient, self).send_method_response(method_response) - - send_method_response.__doc__ = IoTHubDeviceClient.send_method_response.__doc__ - setattr(IoTHubDeviceClient, "send_method_response", send_method_response) - - def shutdown(self): - return super(IoTHubDeviceClient, self).shutdown() - - shutdown.__doc__ = IoTHubDeviceClient.shutdown.__doc__ - setattr(IoTHubDeviceClient, "shutdown", shutdown) - - def update_sastoken(self, sastoken): - return super(IoTHubDeviceClient, self).update_sastoken(sastoken) - - update_sastoken.__doc__ = IoTHubDeviceClient.update_sastoken.__doc__ - setattr(IoTHubDeviceClient, "update_sastoken", update_sastoken) - - def create_from_connection_string(cls, connection_string, **kwargs): - return super(IoTHubDeviceClient, cls).create_from_connection_string( - connection_string, **kwargs - ) - - create_from_connection_string.__doc__ = IoTHubDeviceClient.create_from_connection_string.__doc__ - setattr( - IoTHubDeviceClient, - "create_from_connection_string", - classmethod(create_from_connection_string), - ) - - def create_from_sastoken(cls, sastoken, **kwargs): - return super(IoTHubDeviceClient, cls).create_from_sastoken(sastoken, **kwargs) - - create_from_sastoken.__doc__ = IoTHubDeviceClient.create_from_sastoken.__doc__ - setattr(IoTHubDeviceClient, "create_from_sastoken", classmethod(create_from_sastoken)) - - def create_from_symmetric_key(cls, symmetric_key, hostname, device_id, **kwargs): - return super(IoTHubDeviceClient, cls).create_from_symmetric_key( - symmetric_key, hostname, device_id, **kwargs - ) - - create_from_symmetric_key.__doc__ = IoTHubDeviceClient.create_from_symmetric_key.__doc__ - setattr(IoTHubDeviceClient, "create_from_symmetric_key", classmethod(create_from_symmetric_key)) - - def 
create_from_x509_certificate(cls, x509, hostname, device_id, **kwargs): - return super(IoTHubDeviceClient, cls).create_from_x509_certificate( - x509, hostname, device_id, **kwargs - ) - - create_from_x509_certificate.__doc__ = IoTHubDeviceClient.create_from_x509_certificate.__doc__ - setattr( - IoTHubDeviceClient, - "create_from_x509_certificate", - classmethod(create_from_x509_certificate), - ) - setattr(IoTHubDeviceClient, "connected", IoTHubDeviceClient.connected) - setattr( - IoTHubDeviceClient, "on_background_exception", IoTHubDeviceClient.on_background_exception - ) - setattr( - IoTHubDeviceClient, - "on_connection_state_change", - IoTHubDeviceClient.on_connection_state_change, - ) - setattr(IoTHubDeviceClient, "on_message_received", IoTHubDeviceClient.on_message_received) - setattr( - IoTHubDeviceClient, - "on_method_request_received", - IoTHubDeviceClient.on_method_request_received, - ) - setattr( - IoTHubDeviceClient, "on_new_sastoken_required", IoTHubDeviceClient.on_new_sastoken_required - ) - setattr( - IoTHubDeviceClient, - "on_twin_desired_properties_patch_received", - IoTHubDeviceClient.on_twin_desired_properties_patch_received, - ) - from azure.iot.device.iothub.sync_clients import IoTHubModuleClient as IoTHubModuleClient - - def connect(self): - return super(IoTHubModuleClient, self).connect() - - connect.__doc__ = IoTHubModuleClient.connect.__doc__ - setattr(IoTHubModuleClient, "connect", connect) - - def disconnect(self): - return super(IoTHubModuleClient, self).disconnect() - - disconnect.__doc__ = IoTHubModuleClient.disconnect.__doc__ - setattr(IoTHubModuleClient, "disconnect", disconnect) - - def get_twin(self): - return super(IoTHubModuleClient, self).get_twin() - - get_twin.__doc__ = IoTHubModuleClient.get_twin.__doc__ - setattr(IoTHubModuleClient, "get_twin", get_twin) - - def patch_twin_reported_properties(self, reported_properties_patch): - return super(IoTHubModuleClient, self).patch_twin_reported_properties( - reported_properties_patch - 
) - - patch_twin_reported_properties.__doc__ = ( - IoTHubModuleClient.patch_twin_reported_properties.__doc__ - ) - setattr(IoTHubModuleClient, "patch_twin_reported_properties", patch_twin_reported_properties) - - def receive_method_request(self, method_name=None, block=True, timeout=None): - return super(IoTHubModuleClient, self).receive_method_request(method_name, block, timeout) - - receive_method_request.__doc__ = IoTHubModuleClient.receive_method_request.__doc__ - setattr(IoTHubModuleClient, "receive_method_request", receive_method_request) - - def receive_twin_desired_properties_patch(self, block=True, timeout=None): - return super(IoTHubModuleClient, self).receive_twin_desired_properties_patch(block, timeout) - - receive_twin_desired_properties_patch.__doc__ = ( - IoTHubModuleClient.receive_twin_desired_properties_patch.__doc__ - ) - setattr( - IoTHubModuleClient, - "receive_twin_desired_properties_patch", - receive_twin_desired_properties_patch, - ) - - def send_message(self, message): - return super(IoTHubModuleClient, self).send_message(message) - - send_message.__doc__ = IoTHubModuleClient.send_message.__doc__ - setattr(IoTHubModuleClient, "send_message", send_message) - - def send_method_response(self, method_response): - return super(IoTHubModuleClient, self).send_method_response(method_response) - - send_method_response.__doc__ = IoTHubModuleClient.send_method_response.__doc__ - setattr(IoTHubModuleClient, "send_method_response", send_method_response) - - def shutdown(self): - return super(IoTHubModuleClient, self).shutdown() - - shutdown.__doc__ = IoTHubModuleClient.shutdown.__doc__ - setattr(IoTHubModuleClient, "shutdown", shutdown) - - def update_sastoken(self, sastoken): - return super(IoTHubModuleClient, self).update_sastoken(sastoken) - - update_sastoken.__doc__ = IoTHubModuleClient.update_sastoken.__doc__ - setattr(IoTHubModuleClient, "update_sastoken", update_sastoken) - - def create_from_connection_string(cls, connection_string, **kwargs): - 
return super(IoTHubModuleClient, cls).create_from_connection_string( - connection_string, **kwargs - ) - - create_from_connection_string.__doc__ = IoTHubModuleClient.create_from_connection_string.__doc__ - setattr( - IoTHubModuleClient, - "create_from_connection_string", - classmethod(create_from_connection_string), - ) - - def create_from_edge_environment(cls, **kwargs): - return super(IoTHubModuleClient, cls).create_from_edge_environment(**kwargs) - - create_from_edge_environment.__doc__ = IoTHubModuleClient.create_from_edge_environment.__doc__ - setattr( - IoTHubModuleClient, - "create_from_edge_environment", - classmethod(create_from_edge_environment), - ) - - def create_from_sastoken(cls, sastoken, **kwargs): - return super(IoTHubModuleClient, cls).create_from_sastoken(sastoken, **kwargs) - - create_from_sastoken.__doc__ = IoTHubModuleClient.create_from_sastoken.__doc__ - setattr(IoTHubModuleClient, "create_from_sastoken", classmethod(create_from_sastoken)) - - def create_from_x509_certificate(cls, x509, hostname, device_id, module_id, **kwargs): - return super(IoTHubModuleClient, cls).create_from_x509_certificate( - x509, hostname, device_id, module_id, **kwargs - ) - - create_from_x509_certificate.__doc__ = IoTHubModuleClient.create_from_x509_certificate.__doc__ - setattr( - IoTHubModuleClient, - "create_from_x509_certificate", - classmethod(create_from_x509_certificate), - ) - setattr(IoTHubModuleClient, "connected", IoTHubModuleClient.connected) - setattr( - IoTHubModuleClient, "on_background_exception", IoTHubModuleClient.on_background_exception - ) - setattr( - IoTHubModuleClient, - "on_connection_state_change", - IoTHubModuleClient.on_connection_state_change, - ) - setattr(IoTHubModuleClient, "on_message_received", IoTHubModuleClient.on_message_received) - setattr( - IoTHubModuleClient, - "on_method_request_received", - IoTHubModuleClient.on_method_request_received, - ) - setattr( - IoTHubModuleClient, "on_new_sastoken_required", 
IoTHubModuleClient.on_new_sastoken_required - ) - setattr( - IoTHubModuleClient, - "on_twin_desired_properties_patch_received", - IoTHubModuleClient.on_twin_desired_properties_patch_received, - ) - from azure.iot.device.provisioning.provisioning_device_client import ( - ProvisioningDeviceClient as ProvisioningDeviceClient, - ) - - def create_from_symmetric_key( - cls, provisioning_host, registration_id, id_scope, symmetric_key, **kwargs - ): - return super(ProvisioningDeviceClient, cls).create_from_symmetric_key( - provisioning_host, registration_id, id_scope, symmetric_key, **kwargs - ) - - create_from_symmetric_key.__doc__ = ProvisioningDeviceClient.create_from_symmetric_key.__doc__ - setattr( - ProvisioningDeviceClient, - "create_from_symmetric_key", - classmethod(create_from_symmetric_key), - ) - - def create_from_x509_certificate( - cls, provisioning_host, registration_id, id_scope, x509, **kwargs - ): - return super(ProvisioningDeviceClient, cls).create_from_x509_certificate( - provisioning_host, registration_id, id_scope, x509, **kwargs - ) - - create_from_x509_certificate.__doc__ = ( - ProvisioningDeviceClient.create_from_x509_certificate.__doc__ - ) - setattr( - ProvisioningDeviceClient, - "create_from_x509_certificate", - classmethod(create_from_x509_certificate), - ) - setattr( - ProvisioningDeviceClient, - "provisioning_payload", - ProvisioningDeviceClient.provisioning_payload, - ) diff --git a/azure-iot-device/azure/iot/device/provisioning/__init__.py b/azure-iot-device/azure/iot/device/provisioning/__init__.py deleted file mode 100644 index 09d5fbcb0..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Azure Provisioning Device Library - -This library provides functionality that enables zero-touch, just-in-time provisioning to the right IoT hub without requiring -human intervention, enabling customers to provision millions of devices in a secure and scalable manner. 
- -""" -from .provisioning_device_client import ProvisioningDeviceClient -from .models import RegistrationResult - -__all__ = ["ProvisioningDeviceClient", "RegistrationResult"] diff --git a/azure-iot-device/azure/iot/device/provisioning/abstract_provisioning_device_client.py b/azure-iot-device/azure/iot/device/provisioning/abstract_provisioning_device_client.py deleted file mode 100644 index 152dea1bd..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/abstract_provisioning_device_client.py +++ /dev/null @@ -1,255 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -This module provides an abstract interface representing clients which can communicate with the -Device Provisioning Service. -""" - -import abc -import logging -from azure.iot.device.provisioning import pipeline - -from azure.iot.device.common.auth import sastoken as st -from azure.iot.device.common import auth, handle_exceptions - -logger = logging.getLogger(__name__) - - -def _validate_kwargs(exclude=[], **kwargs): - """Helper function to validate user provided kwargs. 
- Raises TypeError if an invalid option has been provided""" - - valid_kwargs = [ - "server_verification_cert", - "gateway_hostname", - "websockets", - "cipher", - "proxy_options", - "sastoken_ttl", - "keep_alive", - ] - - for kwarg in kwargs: - if (kwarg not in valid_kwargs) or (kwarg in exclude): - raise TypeError("Unsupported keyword argument '{}'".format(kwarg)) - - -def validate_registration_id(reg_id): - if not (reg_id and reg_id.strip()): - raise ValueError("Registration Id can not be none, empty or blank.") - - -def _get_config_kwargs(**kwargs): - """Get the subset of kwargs which pertain the config object""" - valid_config_kwargs = [ - "server_verification_cert", - "gateway_hostname", - "websockets", - "cipher", - "proxy_options", - "keep_alive", - ] - - config_kwargs = {} - for kwarg in kwargs: - if kwarg in valid_config_kwargs: - config_kwargs[kwarg] = kwargs[kwarg] - return config_kwargs - - -def _form_sas_uri(id_scope, registration_id): - return "{id_scope}/registrations/{registration_id}".format( - id_scope=id_scope, registration_id=registration_id - ) - - -class AbstractProvisioningDeviceClient(abc.ABC): - """ - Super class for any client that can be used to register devices to Device Provisioning Service. - """ - - def __init__(self, pipeline): - """ - Initializes the provisioning client. - - NOTE: This initializer should not be called directly. - Instead, the class methods that start with `create_from_` should be used to create a - client object. - - :param pipeline: Instance of the provisioning pipeline object. 
- :type pipeline: :class:`azure.iot.device.provisioning.pipeline.MQTTPipeline` - """ - self._pipeline = pipeline - self._provisioning_payload = None - - self._pipeline.on_background_exception = handle_exceptions.handle_background_exception - - @classmethod - def create_from_symmetric_key( - cls, provisioning_host, registration_id, id_scope, symmetric_key, **kwargs - ): - """ - Create a client which can be used to run the registration of a device with provisioning service - using Symmetric Key authentication. - - :param str provisioning_host: Host running the Device Provisioning Service. - Can be found in the Azure portal in the Overview tab as the string Global device endpoint. - :param str registration_id: The registration ID used to uniquely identify a device in the - Device Provisioning Service. The registration ID is alphanumeric, lowercase string - and may contain hyphens. - :param str id_scope: The ID scope used to uniquely identify the specific provisioning - service the device will register through. The ID scope is assigned to a - Device Provisioning Service when it is created by the user and is generated by the - service and is immutable, guaranteeing uniqueness. - :param str symmetric_key: The key which will be used to create the shared access signature - token to authenticate the device with the Device Provisioning Service. By default, - the Device Provisioning Service creates new symmetric keys with a default length of - 32 bytes when new enrollments are saved with the Auto-generate keys option enabled. - Users can provide their own symmetric keys for enrollments by disabling this option - within 16 bytes and 64 bytes and in valid Base64 format. - - :param str server_verification_cert: Configuration Option. The trusted certificate chain. - Necessary when using connecting to an endpoint which has a non-standard root of trust, - such as a protocol gateway. - :param str gateway_hostname: Configuration Option. The gateway hostname for the gateway - device. 
- :param bool websockets: Configuration Option. Default is False. Set to true if using MQTT - over websockets. - :param cipher: Configuration Option. Cipher suite(s) for TLS/SSL, as a string in - "OpenSSL cipher list format" or as a list of cipher suite strings. - :type cipher: str or list(str) - :param proxy_options: Options for sending traffic through proxy servers. - :type proxy_options: :class:`azure.iot.device.ProxyOptions` - :param int keepalive: Maximum period in seconds between communications with the - broker. If no other messages are being exchanged, this controls the - rate at which the client will send ping messages to the broker. - If not provided default value of 60 secs will be used. - :raises: TypeError if given an unrecognized parameter. - - :returns: A ProvisioningDeviceClient instance which can register via Symmetric Key. - """ - validate_registration_id(registration_id) - # Ensure no invalid kwargs were passed by the user - _validate_kwargs(**kwargs) - - # Create SasToken - uri = _form_sas_uri(id_scope=id_scope, registration_id=registration_id) - signing_mechanism = auth.SymmetricKeySigningMechanism(key=symmetric_key) - token_ttl = kwargs.get("sastoken_ttl", 3600) - try: - sastoken = st.RenewableSasToken(uri, signing_mechanism, ttl=token_ttl) - except st.SasTokenError as e: - new_err = ValueError("Could not create a SasToken using the provided values") - new_err.__cause__ = e - raise new_err - - # Pipeline Config setup - config_kwargs = _get_config_kwargs(**kwargs) - pipeline_configuration = pipeline.ProvisioningPipelineConfig( - hostname=provisioning_host, - registration_id=registration_id, - id_scope=id_scope, - sastoken=sastoken, - **config_kwargs - ) - - # Pipeline setup - mqtt_provisioning_pipeline = pipeline.MQTTPipeline(pipeline_configuration) - - return cls(mqtt_provisioning_pipeline) - - @classmethod - def create_from_x509_certificate( - cls, provisioning_host, registration_id, id_scope, x509, **kwargs - ): - """ - Create a client which 
can be used to run the registration of a device with - provisioning service using X509 certificate authentication. - - :param str provisioning_host: Host running the Device Provisioning Service. Can be found in - the Azure portal in the Overview tab as the string Global device endpoint. - :param str registration_id: The registration ID used to uniquely identify a device in the - Device Provisioning Service. The registration ID is alphanumeric, lowercase string - and may contain hyphens. - :param str id_scope: The ID scope is used to uniquely identify the specific - provisioning service the device will register through. The ID scope is assigned to a - Device Provisioning Service when it is created by the user and is generated by the - service and is immutable, guaranteeing uniqueness. - :param x509: The x509 certificate, To use the certificate the enrollment object needs to - contain cert (either the root certificate or one of the intermediate CA certificates). - If the cert comes from a CER file, it needs to be base64 encoded. - :type x509: :class:`azure.iot.device.X509` - - :param str server_verification_cert: Configuration Option. The trusted certificate chain. - Necessary when using connecting to an endpoint which has a non-standard root of trust, - such as a protocol gateway. - :param str gateway_hostname: Configuration Option. The gateway hostname for the gateway - device. - :param bool websockets: Configuration Option. Default is False. Set to true if using MQTT - over websockets. - :param cipher: Configuration Option. Cipher suite(s) for TLS/SSL, as a string in - "OpenSSL cipher list format" or as a list of cipher suite strings. - :type cipher: str or list(str) - :param proxy_options: Options for sending traffic through proxy servers. - :type proxy_options: :class:`azure.iot.device.ProxyOptions` - :param int keepalive: Maximum period in seconds between communications with the - broker. 
If no other messages are being exchanged, this controls the - rate at which the client will send ping messages to the broker. - If not provided default value of 60 secs will be used. - :raises: TypeError if given an unrecognized parameter. - - :returns: A ProvisioningDeviceClient which can register via X509 client certificates. - """ - validate_registration_id(registration_id) - # Ensure no invalid kwargs were passed by the user - excluded_kwargs = ["sastoken_ttl"] - _validate_kwargs(exclude=excluded_kwargs, **kwargs) - - # Pipeline Config setup - config_kwargs = _get_config_kwargs(**kwargs) - pipeline_configuration = pipeline.ProvisioningPipelineConfig( - hostname=provisioning_host, - registration_id=registration_id, - id_scope=id_scope, - x509=x509, - **config_kwargs - ) - - # Pipeline setup - mqtt_provisioning_pipeline = pipeline.MQTTPipeline(pipeline_configuration) - - return cls(mqtt_provisioning_pipeline) - - @abc.abstractmethod - def register(self): - """ - Register the device with the Device Provisioning Service. - """ - pass - - @property - def provisioning_payload(self): - return self._provisioning_payload - - @provisioning_payload.setter - def provisioning_payload(self, provisioning_payload): - """ - Set the payload that will form the request payload in a registration request. - - :param provisioning_payload: The payload that can be supplied by the user. - :type provisioning_payload: This can be an object or dictionary or a string or an integer. - """ - self._provisioning_payload = provisioning_payload - - -def log_on_register_complete(result=None): - # This could be a failed/successful registration result from DPS - # or a error from polling machine. 
Response should be given appropriately - if result is not None: - if result.status == "assigned": - logger.info("Successfully registered with Provisioning Service") - else: # There be other statuses - logger.info("Failed registering with Provisioning Service") diff --git a/azure-iot-device/azure/iot/device/provisioning/aio/__init__.py b/azure-iot-device/azure/iot/device/provisioning/aio/__init__.py deleted file mode 100644 index ac7977b17..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/aio/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Azure IoT Provisioning Service SDK - Asynchronous - -This SDK provides asynchronous functionality for communicating with the Azure Provisioning Service -as a Device. -""" - -from .async_provisioning_device_client import ProvisioningDeviceClient - -__all__ = ["ProvisioningDeviceClient"] diff --git a/azure-iot-device/azure/iot/device/provisioning/aio/async_provisioning_device_client.py b/azure-iot-device/azure/iot/device/provisioning/aio/async_provisioning_device_client.py deleted file mode 100644 index acd79c533..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/aio/async_provisioning_device_client.py +++ /dev/null @@ -1,106 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -This module contains user-facing asynchronous Provisioning Device Client for Azure Provisioning -Device SDK. This client uses Symmetric Key and X509 authentication to register devices with an -IoT Hub via the Device Provisioning Service. 
-""" - -import logging -from azure.iot.device.common import async_adapter -from azure.iot.device.provisioning.abstract_provisioning_device_client import ( - AbstractProvisioningDeviceClient, -) -from azure.iot.device.provisioning.abstract_provisioning_device_client import ( - log_on_register_complete, -) -from azure.iot.device.provisioning.pipeline import exceptions as pipeline_exceptions -from azure.iot.device import exceptions -from azure.iot.device.provisioning.pipeline import constant as dps_constant - -logger = logging.getLogger(__name__) - - -async def handle_result(callback): - try: - return await callback.completion() - except pipeline_exceptions.ConnectionDroppedError as e: - raise exceptions.ConnectionDroppedError("Lost connection to IoTHub") from e - except pipeline_exceptions.ConnectionFailedError as e: - raise exceptions.ConnectionFailedError("Could not connect to IoTHub") from e - except pipeline_exceptions.UnauthorizedError as e: - raise exceptions.CredentialError("Credentials invalid, could not connect") from e - except pipeline_exceptions.ProtocolClientError as e: - raise exceptions.ClientError("Error in the IoTHub client") from e - except pipeline_exceptions.OperationTimeout as e: - raise exceptions.OperationTimeout("Could not complete operation before timeout") from e - except pipeline_exceptions.PipelineNotRunning as e: - raise exceptions.ClientError("Client has already been shut down") from e - except Exception as e: - raise exceptions.ClientError("Unexpected failure") from e - - -class ProvisioningDeviceClient(AbstractProvisioningDeviceClient): - """ - Client which can be used to run the registration of a device with provisioning service - using Symmetric Key or X509 authentication. - """ - - async def register(self): - """ - Register the device with the provisioning service. - - Before returning the client will also disconnect from the provisioning service. 
- If a registration attempt is made while a previous registration is in progress it may - throw an error. - - Once the device is successfully registered, the client will no longer be operable. - - :returns: RegistrationResult indicating the result of the registration. - :rtype: :class:`azure.iot.device.RegistrationResult` - - :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid - and a connection cannot be established. - :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a - connection results in failure. - :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost - during execution. - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. - - """ - logger.info("Registering with Provisioning Service...") - - if not self._pipeline.responses_enabled[dps_constant.REGISTER]: - await self._enable_responses() - - register_async = async_adapter.emulate_async(self._pipeline.register) - register_complete = async_adapter.AwaitableCallback(return_arg_name="result") - await register_async(payload=self._provisioning_payload, callback=register_complete) - result = await handle_result(register_complete) - - log_on_register_complete(result) - - if result is not None and result.status == "assigned": - logger.debug("Beginning pipeline shutdown operation") - shutdown_async = async_adapter.emulate_async(self._pipeline.shutdown) - callback = async_adapter.AwaitableCallback() - await shutdown_async(callback=callback) - await handle_result(callback) - logger.debug("Completed pipeline shutdown operation") - - return result - - async def _enable_responses(self): - """Enable to receive responses from Device Provisioning Service.""" - logger.info("Enabling reception of response from Device Provisioning Service...") - subscribe_async = async_adapter.emulate_async(self._pipeline.enable_responses) - - subscription_complete = 
async_adapter.AwaitableCallback() - await subscribe_async(callback=subscription_complete) - await handle_result(subscription_complete) - - logger.info("Successfully subscribed to Device Provisioning Service to receive responses") diff --git a/azure-iot-device/azure/iot/device/provisioning/models/__init__.py b/azure-iot-device/azure/iot/device/provisioning/models/__init__.py deleted file mode 100644 index 27385514e..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/models/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Azure Provisioning Device Models - -This package provides object models for use within the Azure Provisioning Device SDK. -""" - -from .registration_result import RegistrationResult # noqa: F401 diff --git a/azure-iot-device/azure/iot/device/provisioning/models/registration_result.py b/azure-iot-device/azure/iot/device/provisioning/models/registration_result.py deleted file mode 100644 index 560d76720..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/models/registration_result.py +++ /dev/null @@ -1,119 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import json - - -class RegistrationResult(object): - """ - The final result of a completed or failed registration attempt - :ivar:request_id: The request id to which the response is being obtained - :ivar:operation_id: The id of the operation as returned by the registration request. - :ivar status: The status of the registration process as returned by the provisioning service. - Values can be "unassigned", "assigning", "assigned", "failed", "disabled" - :ivar registration_state : Details like device id, assigned hub , date times etc returned - from the provisioning service. 
- """ - - def __init__(self, operation_id, status, registration_state=None): - """ - :param operation_id: The id of the operation as returned by the initial registration request. - :param status: The status of the registration process. - Values can be "unassigned", "assigning", "assigned", "failed", "disabled" - :param registration_state : Details like device id, assigned hub , date times etc returned - from the provisioning service. - """ - self._operation_id = operation_id - self._status = status - self._registration_state = registration_state - - @property - def operation_id(self): - return self._operation_id - - @property - def status(self): - return self._status - - @property - def registration_state(self): - return self._registration_state - - def __str__(self): - return "\n".join([str(self.registration_state), self.status]) - - -class RegistrationState(object): - """ - The registration state regarding the device. - :ivar device_id: Desired device id for the provisioned device - :ivar assigned_hub: Desired IoT Hub to which the device is linked. - :ivar sub_status: Substatus for 'Assigned' devices. Possible values are - "initialAssignment", "deviceDataMigrated", "deviceDataReset" - :ivar created_date_time: Registration create date time (in UTC). - :ivar last_update_date_time: Last updated date time (in UTC). - :ivar etag: The entity tag associated with the resource. - """ - - def __init__( - self, - device_id=None, - assigned_hub=None, - sub_status=None, - created_date_time=None, - last_update_date_time=None, - etag=None, - payload=None, - ): - """ - :param device_id: Desired device id for the provisioned device - :param assigned_hub: Desired IoT Hub to which the device is linked. - :param sub_status: Substatus for 'Assigned' devices. Possible values are - "initialAssignment", "deviceDataMigrated", "deviceDataReset" - :param created_date_time: Registration create date time (in UTC). - :param last_update_date_time: Last updated date time (in UTC). 
- :param etag: The entity tag associated with the resource. - :param payload: The payload with which hub is responding - """ - self._device_id = device_id - self._assigned_hub = assigned_hub - self._sub_status = sub_status - self._created_date_time = created_date_time - self._last_update_date_time = last_update_date_time - self._etag = etag - self._response_payload = payload - - @property - def device_id(self): - return self._device_id - - @property - def assigned_hub(self): - return self._assigned_hub - - @property - def sub_status(self): - return self._sub_status - - @property - def created_date_time(self): - return self._created_date_time - - @property - def last_update_date_time(self): - return self._last_update_date_time - - @property - def etag(self): - return self._etag - - @property - def response_payload(self): - return json.dumps(self._response_payload, default=lambda o: o.__dict__, sort_keys=True) - - def __str__(self): - return "\n".join( - [self.device_id, self.assigned_hub, self.sub_status, self.response_payload] - ) diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/__init__.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/__init__.py deleted file mode 100644 index 09fd9d824..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Azure Provisioning Device Communication Pipeline - -This package provides pipeline for use with the Azure Provisioning Device SDK. 
- -INTERNAL USAGE ONLY -""" -from .mqtt_pipeline import MQTTPipeline # noqa: F401 -from .config import ProvisioningPipelineConfig # noqa: F401 diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/config.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/config.py deleted file mode 100644 index c785c5bf3..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/config.py +++ /dev/null @@ -1,29 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -from azure.iot.device.common.pipeline.config import BasePipelineConfig - -logger = logging.getLogger(__name__) - - -class ProvisioningPipelineConfig(BasePipelineConfig): - """A class for storing all configurations/options for Provisioning clients in the Azure IoT Python Device Client Library.""" - - def __init__(self, hostname, registration_id, id_scope, **kwargs): - """Initializer for ProvisioningPipelineConfig which passes all unrecognized keyword-args down to BasePipelineConfig - to be evaluated. This stacked options setting is to allow for unique configuration options to exist between the - multiple clients, while maintaining a base configuration class with shared config options. 
- - :param str hostname: The hostname of the Provisioning hub instance to connect to - :param str registration_id: The device registration identity being provisioned - :param str id_scope: The identity of the provisioning service being used - """ - super().__init__(hostname=hostname, **kwargs) - - # Provisioning Connection Details - self.registration_id = registration_id - self.id_scope = id_scope diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/constant.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/constant.py deleted file mode 100644 index 6eb52a660..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/constant.py +++ /dev/null @@ -1,34 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -This module contains constants related to the pipeline package. -""" - -REGISTER = "register" -QUERY = "query" - -""" -Default interval for polling, to use in case service doesn't provide it to us. -""" -DEFAULT_POLLING_INTERVAL = 2 - -""" -Default timeout to use when communicating with the service -""" - -DEFAULT_TIMEOUT_INTERVAL = 30 - -SUBSCRIBE_TOPIC_PROVISIONING = "$dps/registrations/res/#" -""" -The first part of the topic string used for publishing. -The registration request id (rid) value is appended to this. -""" -PUBLISH_TOPIC_REGISTRATION = "$dps/registrations/PUT/iotdps-register/?$rid={}" -""" -The topic string used for publishing a query request. 
-This must be provided with the registration request id (rid) as well as the operation id -""" -PUBLISH_TOPIC_QUERYING = "$dps/registrations/GET/iotdps-get-operationstatus/?$rid={}&operationId={}" diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/exceptions.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/exceptions.py deleted file mode 100644 index c99b4f9b4..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/exceptions.py +++ /dev/null @@ -1,21 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module defines an exception surface, exposed as part of the pipeline API""" - -# For now, present relevant transport errors as part of the Pipeline API surface -# so that they do not have to be duplicated at this layer. -# OK TODO This mimics the IotHub Case. Both IotHub and Provisioning needs to change -from azure.iot.device.common.pipeline.pipeline_exceptions import * # noqa: F401, F403 -from azure.iot.device.common.transport_exceptions import ( # noqa: F401 - ConnectionFailedError, - ConnectionDroppedError, - # CT TODO: UnauthorizedError (the one from transport) should probably not surface out of - # the pipeline due to confusion with the higher level service UnauthorizedError. It - # should probably get turned into some other error instead (e.g. ConnectionFailedError). - # But for now, this is a stopgap. 
- UnauthorizedError, - ProtocolClientError, -) diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/mqtt_pipeline.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/mqtt_pipeline.py deleted file mode 100644 index e410cd2c1..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/mqtt_pipeline.py +++ /dev/null @@ -1,279 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -from azure.iot.device.common.evented_callback import EventedCallback -from azure.iot.device.common.pipeline import ( - pipeline_stages_base, - pipeline_ops_base, - pipeline_stages_mqtt, - pipeline_exceptions, - pipeline_nucleus, -) -from azure.iot.device.provisioning.pipeline import ( - pipeline_stages_provisioning, - pipeline_stages_provisioning_mqtt, -) -from azure.iot.device.provisioning.pipeline import pipeline_ops_provisioning -from azure.iot.device.provisioning.pipeline import constant as provisioning_constants - -logger = logging.getLogger(__name__) - - -class MQTTPipeline(object): - def __init__(self, pipeline_configuration): - """ - Constructor for instantiating a pipeline - :param security_client: The security client which stores credentials - """ - self.responses_enabled = {provisioning_constants.REGISTER: False} - - # Event Handlers - Will be set by Client after instantiation of pipeline - self.on_connected = None - self.on_disconnected = None - self.on_background_exception = None - self.on_message_received = None - self._registration_id = pipeline_configuration.registration_id - - # Contains data and information shared globally within the pipeline - self._nucleus = pipeline_nucleus.PipelineNucleus(pipeline_configuration) - - self._pipeline = 
( - # - # The root is always the root. By definition, it's the first stage in the pipeline. - # - pipeline_stages_base.PipelineRootStage(self._nucleus) - # - # SasTokenStage comes near the root by default because it should be as close - # to the top of the pipeline as possible, and does not need to be after anything. - # - .append_stage(pipeline_stages_base.SasTokenStage()) - # - # RegistrationStage needs to come early because this is the stage that converts registration - # or query requests into request and response objects which are used by later stages - # - .append_stage(pipeline_stages_provisioning.RegistrationStage()) - # - # PollingStatusStage needs to come after RegistrationStage because RegistrationStage counts - # on PollingStatusStage to poll until the registration is complete. - # - .append_stage(pipeline_stages_provisioning.PollingStatusStage()) - # - # CoordinateRequestAndResponseStage needs to be after RegistrationStage and PollingStatusStage - # because these 2 stages create the request ops that CoordinateRequestAndResponseStage - # is coordinating. It needs to be before ProvisioningMQTTTranslationStage because that stage - # operates on ops that CoordinateRequestAndResponseStage produces - # - .append_stage(pipeline_stages_base.CoordinateRequestAndResponseStage()) - # - # ProvisioningMQTTTranslationStage comes here because this is the point where we can translate - # all operations directly into MQTT. After this stage, only pipeline_stages_base stages - # are allowed because ProvisioningMQTTTranslationStage removes all the provisioning-ness from the ops - # - .append_stage(pipeline_stages_provisioning_mqtt.ProvisioningMQTTTranslationStage()) - # - # AutoConnectStage comes here because only MQTT ops have the need_connection flag set - # and this is the first place in the pipeline where we can guarantee that all network - # ops are MQTT ops. 
- # - .append_stage(pipeline_stages_base.AutoConnectStage()) - # - # ConnectionStateStage needs to be after AutoConnectStage because the AutoConnectStage - # can create ConnectOperations and we (may) want to queue connection related operations - # in the ConnectionStateStage - # - .append_stage(pipeline_stages_base.ConnectionStateStage()) - # - # RetryStage needs to be near the end because it's retrying low-level MQTT operations. - # - .append_stage(pipeline_stages_base.RetryStage()) - # - # OpTimeoutStage needs to be after RetryStage because OpTimeoutStage returns the timeout - # errors that RetryStage is watching for. - # - .append_stage(pipeline_stages_base.OpTimeoutStage()) - # - # MQTTTransportStage needs to be at the very end of the pipeline because this is where - # operations turn into network traffic - # - .append_stage(pipeline_stages_mqtt.MQTTTransportStage()) - ) - - def _on_pipeline_event(event): - # error because no events should - logger.debug("Dropping unknown pipeline event {}".format(event.name)) - - def _on_connected(): - if self.on_connected: - self.on_connected("connected") - - def _on_disconnected(): - if self.on_disconnected: - self.on_disconnected("disconnected") - - def _on_background_exception(): - if self.on_background_exception: - self.on_background_exception - - self._pipeline.on_pipeline_event_handler = _on_pipeline_event - self._pipeline.on_connected_handler = _on_connected - self._pipeline.on_disconnected_handler = _on_disconnected - - callback = EventedCallback() - op = pipeline_ops_base.InitializePipelineOperation(callback=callback) - - self._pipeline.run_op(op) - callback.wait_for_completion() - - # Set the running flag - self._running = True - - def _verify_running(self): - if not self._running: - raise pipeline_exceptions.PipelineNotRunning( - "Cannot execute method - Pipeline is not running" - ) - - def shutdown(self, callback): - """Shut down the pipeline and clean up any resources. 
- - Once shut down, making any further calls on the pipeline will result in a - PipelineNotRunning exception being raised. - - There is currently no way to resume pipeline functionality once shutdown has occurred. - - :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.PipelineNotRunning` if the - pipeline has already been shut down - - The shutdown process itself is not expected to fail under any normal condition, but if it - does, exceptions are not "raised", but rather, returned via the "error" parameter when - invoking "callback". - """ - self._verify_running() - logger.debug("Commencing shutdown of pipeline") - - def on_complete(op, error): - if not error: - # Only set the pipeline to not be running if the op was successful - self._running = False - callback(error=error) - - # NOTE: While we do run this operation, its functionality is incomplete. Some stages still - # need a response to this operation implemented. Additionally, there are other pipeline - # constructs other than Stages (e.g. Operations) which may have timers attached. These are - # lesser issues, but should be addressed at some point. - # TODO: Truly complete the shutdown implementation - self._pipeline.run_op(pipeline_ops_base.ShutdownPipelineOperation(callback=on_complete)) - - def connect(self, callback=None): - """ - Connect to the service. - - :param callback: callback which is called when the connection to the service is complete. 
- - :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.PipelineNotRunning` if the - pipeline has already been shut down - - The following exceptions are not "raised", but rather returned via the "error" parameter - when invoking "callback": - - :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.ConnectionFailedError` - :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.ConnectionDroppedError` - :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.UnauthorizedError` - :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.ProtocolClientError` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.OperationTimeout` - """ - self._verify_running() - logger.debug("connect called") - - def pipeline_callback(op, error): - callback(error=error) - - self._pipeline.run_op(pipeline_ops_base.ConnectOperation(callback=pipeline_callback)) - - def disconnect(self, callback=None): - """ - Disconnect from the service. - - :param callback: callback which is called when the connection to the service has been disconnected - - :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.PipelineNotRunning` if the - pipeline has already been shut down - - The following exceptions are not "raised", but rather returned via the "error" parameter - when invoking "callback": - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` - """ - self._verify_running() - logger.debug("disconnect called") - - def pipeline_callback(op, error): - callback(error=error) - - self._pipeline.run_op(pipeline_ops_base.DisconnectOperation(callback=pipeline_callback)) - - # NOTE: Currently, this operation will retry itself indefinitely in the case of timeout - def enable_responses(self, callback=None): - """ - Enable response from the DPS service by subscribing to the appropriate topics. 
- - :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.PipelineNotRunning` if the - pipeline has already been shut down - - :param callback: callback which is called when responses are enabled - """ - self._verify_running() - logger.debug("enable_responses called") - - self.responses_enabled[provisioning_constants.REGISTER] = True - - def pipeline_callback(op, error): - callback(error=error) - - self._pipeline.run_op( - pipeline_ops_base.EnableFeatureOperation( - feature_name=provisioning_constants.REGISTER, callback=pipeline_callback - ) - ) - - def register(self, payload=None, callback=None): - """ - Register to the device provisioning service. - :param payload: Payload that can be sent with the registration request. - :param callback: callback which is called when the registration is done. - - :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.PipelineNotRunning` if the - pipeline has already been shut down - - The following exceptions are not "raised", but rather returned via the "error" parameter - when invoking "callback": - - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.NoConnectionError` - :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.ProtocolClientError - - The following exceptions can be returned via the "error" parameter only if auto-connect - is enabled in the pipeline configuration: - - :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.ConnectionFailedError` - :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.ConnectionDroppedError` - :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.UnauthorizedError`` - :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.OperationTimeout` - """ - self._verify_running() - - def on_complete(op, error): - # TODO : Apparently when its failed we can get result as well as error. 
- if error: - callback(error=error, result=None) - else: - callback(result=op.registration_result) - - self._pipeline.run_op( - pipeline_ops_provisioning.RegisterOperation( - request_payload=payload, registration_id=self._registration_id, callback=on_complete - ) - ) diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/mqtt_topic_provisioning.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/mqtt_topic_provisioning.py deleted file mode 100644 index 38660c6ad..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/mqtt_topic_provisioning.py +++ /dev/null @@ -1,121 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -import urllib - -logger = logging.getLogger(__name__) - -# NOTE: Whenever using standard URL encoding via the urllib.parse.quote() API -# make sure to specify that there are NO safe values (e.g. safe=""). By default -# "/" is skipped in encoding, and that is not desirable. -# -# DO NOT use urllib.parse.quote_plus(), as it turns ' ' characters into '+', -# which is invalid for MQTT publishes. -# -# DO NOT use urllib.parse.unquote_plus(), as it turns '+' characters into ' ', -# which is also invalid. - - -def _get_topic_base(): - """ - return the string that creates the beginning of all topics for DPS - """ - return "$dps/registrations/" - - -def get_register_topic_for_subscribe(): - """ - :return: The topic string used to subscribe for receiving future responses from DPS. - It is of the format "$dps/registrations/res/#" - """ - return _get_topic_base() + "res/#" - - -def get_register_topic_for_publish(request_id): - """ - :return: The topic string used to send a registration. 
It is of the format - "$dps/registrations/PUT/iotdps-register/?$rid= - """ - return (_get_topic_base() + "PUT/iotdps-register/?$rid={request_id}").format( - request_id=urllib.parse.quote(str(request_id), safe="") - ) - - -def get_query_topic_for_publish(request_id, operation_id): - """ - :return: The topic string used to send a query. It is of the format - "$dps/registrations/GET/iotdps-get-operationstatus/?$rid=&operationId= - """ - return ( - _get_topic_base() - + "GET/iotdps-get-operationstatus/?$rid={request_id}&operationId={operation_id}" - ).format( - request_id=urllib.parse.quote(str(request_id), safe=""), - operation_id=urllib.parse.quote(str(operation_id), safe=""), - ) - - -def _get_topic_for_response(): - """ - return the topic string used to publish telemetry - """ - return _get_topic_base() + "res/" - - -def is_dps_response_topic(topic): - """ - Topics for responses from DPS are of the following format: - $dps/registrations/res//?$=&=...&= - :param topic: The topic string - """ - if _get_topic_for_response() in topic: - return True - return False - - -def extract_properties_from_dps_response_topic(topic): - """ - Topics for responses from DPS are of the following format: - $dps/registrations/res//?$=&=...&= - Extract key=value pairs from the latter part of the topic. - :param topic: The topic string - :return: a dictionary of property keys mapped to property values. - """ - topic_parts = topic.split("$") - properties = topic_parts[2] - - # NOTE: we cannot use urllib.parse.parse_qs because it always decodes '+' as ' ', - # and the behavior cannot be overridden. Must parse key/value pairs manually. 
- - if properties: - key_value_pairs = properties.split("&") - key_value_dict = {} - for entry in key_value_pairs: - pair = entry.split("=") - key = urllib.parse.unquote(pair[0]) - value = urllib.parse.unquote(pair[1]) - if key_value_dict.get(key): - raise ValueError("Duplicate keys in DPS response topic") - else: - key_value_dict[key] = value - - return key_value_dict - - -def extract_status_code_from_dps_response_topic(topic): - """ - Topics for responses from DPS are of the following format: - $dps/registrations/res//?$=&=...&= - Extract the status code part from the topic. - :param topic: The topic string - :return: The status code from the DPS response topic, as a string - """ - POS_STATUS_CODE_IN_TOPIC = 3 - topic_parts = topic.split("$") - url_parts = topic_parts[1].split("/") - status_code = url_parts[POS_STATUS_CODE_IN_TOPIC] - return urllib.parse.unquote(status_code) diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_ops_provisioning.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_ops_provisioning.py deleted file mode 100644 index 2a670cc0e..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_ops_provisioning.py +++ /dev/null @@ -1,60 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from azure.iot.device.common.pipeline.pipeline_ops_base import PipelineOperation - - -class RegisterOperation(PipelineOperation): - """ - A PipelineOperation object which contains arguments used to send a registration request - to an Device Provisioning Service. - - This operation is in the group of DPS operations because it is very specific to the DPS client. 
- """ - - def __init__(self, request_payload, registration_id, callback, registration_result=None): - """ - Initializer for RegisterOperation objects. - - :param request_payload: The request that we are sending to the service - :param registration_id: The registration ID is used to uniquely identify a device in the Device Provisioning Service. - :param Function callback: The function that gets called when this operation is complete or has failed. - The callback function must accept A PipelineOperation object which indicates the specific operation which - has completed or failed. - """ - super().__init__(callback=callback) - self.request_payload = request_payload - self.registration_id = registration_id - self.registration_result = registration_result - self.retry_after_timer = None - self.polling_timer = None - self.provisioning_timeout_timer = None - - -class PollStatusOperation(PipelineOperation): - """ - A PipelineOperation object which contains arguments used to send a registration request - to an Device Provisioning Service. - - This operation is in the group of DPS operations because it is very specific to the DPS client. - """ - - def __init__(self, operation_id, request_payload, callback, registration_result=None): - """ - Initializer for PollStatusOperation objects. - - :param operation_id: The id of the existing operation for which the polling was started. - :param request_payload: The request that we are sending to the service - :param Function callback: The function that gets called when this operation is complete or has failed. - The callback function must accept A PipelineOperation object which indicates the specific operation which - has completed or failed. 
- """ - super().__init__(callback=callback) - self.operation_id = operation_id - self.request_payload = request_payload - self.registration_result = registration_result - self.retry_after_timer = None - self.polling_timer = None - self.provisioning_timeout_timer = None diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning.py deleted file mode 100644 index 2f739d74b..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning.py +++ /dev/null @@ -1,461 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -from azure.iot.device.common.pipeline import pipeline_ops_base, pipeline_thread -from azure.iot.device.common.pipeline.pipeline_stages_base import PipelineStage -from . import pipeline_ops_provisioning -from azure.iot.device import exceptions -from azure.iot.device.provisioning.pipeline import constant -from azure.iot.device.provisioning.models.registration_result import ( - RegistrationResult, - RegistrationState, -) -import logging -import weakref -import json -from threading import Timer - -logger = logging.getLogger(__name__) - - -class CommonProvisioningStage(PipelineStage): - """ - This is a super stage that the RegistrationStage and PollingStatusStage of - provisioning would both use. It contains some common functions like decoding response - and retrieving error, retrieving registration status, retrieving operation id - and forming a complete result. 
- """ - - @pipeline_thread.runs_on_pipeline_thread - def _clear_timeout_timer(self, op, error): - """ - Clearing timer for provisioning operations (Register and PollStatus) - when they respond back from service. - """ - if op.provisioning_timeout_timer: - logger.debug("{}({}): Cancelling provisioning timeout timer".format(self.name, op.name)) - op.provisioning_timeout_timer.cancel() - op.provisioning_timeout_timer = None - - @staticmethod - def _decode_response(provisioning_op): - return json.loads(provisioning_op.response_body.decode("utf-8")) - - @staticmethod - def _form_complete_result(operation_id, decoded_response, status): - """ - Create the registration result from the complete decoded json response for details regarding the registration process. - """ - decoded_state = decoded_response.get("registrationState", None) - registration_state = None - if decoded_state is not None: - registration_state = RegistrationState( - device_id=decoded_state.get("deviceId", None), - assigned_hub=decoded_state.get("assignedHub", None), - sub_status=decoded_state.get("substatus", None), - created_date_time=decoded_state.get("createdDateTimeUtc", None), - last_update_date_time=decoded_state.get("lastUpdatedDateTimeUtc", None), - etag=decoded_state.get("etag", None), - payload=decoded_state.get("payload", None), - ) - - registration_result = RegistrationResult( - operation_id=operation_id, status=status, registration_state=registration_state - ) - return registration_result - - def _process_service_error_status_code(self, original_provisioning_op, request_response_op): - logger.info( - "{stage_name}({op_name}): Received error with status code {status_code} for {prov_op_name} request operation".format( - stage_name=self.name, - op_name=request_response_op.name, - prov_op_name=request_response_op.request_type, - status_code=request_response_op.status_code, - ) - ) - logger.debug( - "{stage_name}({op_name}): Response body: {body}".format( - stage_name=self.name, - 
op_name=request_response_op.name, - body=request_response_op.response_body, - ) - ) - original_provisioning_op.complete( - error=exceptions.ServiceError( - "{prov_op_name} request returned a service error status code {status_code}".format( - prov_op_name=request_response_op.request_type, - status_code=request_response_op.status_code, - ) - ) - ) - - def _process_retry_status_code(self, error, original_provisioning_op, request_response_op): - retry_interval = ( - int(request_response_op.retry_after, 10) - if request_response_op.retry_after is not None - else constant.DEFAULT_POLLING_INTERVAL - ) - - self_weakref = weakref.ref(self) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def do_retry_after(): - this = self_weakref() - logger.info( - "{stage_name}({op_name}): retrying".format( - stage_name=this.name, op_name=request_response_op.name - ) - ) - original_provisioning_op.retry_after_timer.cancel() - original_provisioning_op.retry_after_timer = None - original_provisioning_op.completed = False - this.run_op(original_provisioning_op) - - logger.info( - "{stage_name}({op_name}): Op needs retry with interval {interval} because of {error}. 
Setting timer.".format( - stage_name=self.name, - op_name=request_response_op.name, - interval=retry_interval, - error=error, - ) - ) - - logger.debug("{}({}): Creating retry timer".format(self.name, request_response_op.name)) - original_provisioning_op.retry_after_timer = Timer(retry_interval, do_retry_after) - original_provisioning_op.retry_after_timer.start() - - @staticmethod - def _process_failed_and_assigned_registration_status( - error, - operation_id, - decoded_response, - registration_status, - original_provisioning_op, - request_response_op, - ): - complete_registration_result = CommonProvisioningStage._form_complete_result( - operation_id=operation_id, decoded_response=decoded_response, status=registration_status - ) - original_provisioning_op.registration_result = complete_registration_result - if registration_status == "failed": - error = exceptions.ServiceError( - "Query Status operation returned a failed registration status with a status code of {status_code}".format( - status_code=request_response_op.status_code - ) - ) - original_provisioning_op.complete(error=error) - - @staticmethod - def _process_unknown_registration_status( - registration_status, original_provisioning_op, request_response_op - ): - error = exceptions.ServiceError( - "Query Status Operation encountered an invalid registration status {status} with a status code of {status_code}".format( - status=registration_status, status_code=request_response_op.status_code - ) - ) - original_provisioning_op.complete(error=error) - - -class PollingStatusStage(CommonProvisioningStage): - """ - This stage is responsible for sending the query request once initial response - is received from the registration response. - Upon the receipt of the response this stage decides whether - to send another query request or complete the procedure. 
- """ - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - if isinstance(op, pipeline_ops_provisioning.PollStatusOperation): - query_status_op = op - self_weakref = weakref.ref(self) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def query_timeout(): - this = self_weakref() - logger.info( - "{stage_name}({op_name}): returning timeout error".format( - stage_name=this.name, op_name=op.name - ) - ) - query_status_op.complete( - error=( - exceptions.ServiceError( - "Operation timed out before provisioning service could respond for {op_type} operation".format( - op_type=constant.QUERY - ) - ) - ) - ) - - logger.debug("{}({}): Creating provisioning timeout timer".format(self.name, op.name)) - query_status_op.provisioning_timeout_timer = Timer( - constant.DEFAULT_TIMEOUT_INTERVAL, query_timeout - ) - query_status_op.provisioning_timeout_timer.start() - - def on_query_response(op, error): - self._clear_timeout_timer(query_status_op, error) - logger.debug( - "{stage_name}({op_name}): Received response with status code {status_code} for PollStatusOperation with operation id {oper_id}".format( - stage_name=self.name, - op_name=op.name, - status_code=op.status_code, - oper_id=op.query_params["operation_id"], - ) - ) - - if error: - logger.debug( - "{stage_name}({op_name}): Received error for {prov_op_name} operation".format( - stage_name=self.name, op_name=op.name, prov_op_name=op.request_type - ) - ) - query_status_op.complete(error=error) - - else: - if 300 <= op.status_code < 429: - self._process_service_error_status_code(query_status_op, op) - - elif op.status_code >= 429: - self._process_retry_status_code(error, query_status_op, op) - - else: - decoded_response = self._decode_response(op) - operation_id = decoded_response.get("operationId", None) - registration_status = decoded_response.get("status", None) - if registration_status == "assigning": - polling_interval = ( - int(op.retry_after, 10) - if op.retry_after is not None - else 
constant.DEFAULT_POLLING_INTERVAL - ) - self_weakref = weakref.ref(self) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def do_polling(): - this = self_weakref() - logger.info( - "{stage_name}({op_name}): retrying".format( - stage_name=this.name, op_name=op.name - ) - ) - query_status_op.polling_timer.cancel() - query_status_op.polling_timer = None - query_status_op.completed = False - this.run_op(query_status_op) - - logger.debug( - "{stage_name}({op_name}): Op needs retry with interval {interval} because of {error}. Setting timer.".format( - stage_name=self.name, - op_name=op.name, - interval=polling_interval, - error=error, - ) - ) - - logger.debug( - "{}({}): Creating polling timer".format(self.name, op.name) - ) - query_status_op.polling_timer = Timer(polling_interval, do_polling) - query_status_op.polling_timer.start() - - elif registration_status == "assigned" or registration_status == "failed": - self._process_failed_and_assigned_registration_status( - error=error, - operation_id=operation_id, - decoded_response=decoded_response, - registration_status=registration_status, - original_provisioning_op=query_status_op, - request_response_op=op, - ) - - else: - self._process_unknown_registration_status( - registration_status=registration_status, - original_provisioning_op=query_status_op, - request_response_op=op, - ) - - self.send_op_down( - pipeline_ops_base.RequestAndResponseOperation( - request_type=constant.QUERY, - method="GET", - resource_location="/", - query_params={"operation_id": query_status_op.operation_id}, - request_body=query_status_op.request_payload, - callback=on_query_response, - ) - ) - - else: - super()._run_op(op) - - -class RegistrationStage(CommonProvisioningStage): - """ - This is the first stage that decides converts a registration request - into a normal request and response operation. - Upon the receipt of the response this stage decides whether - to send another registration request or send a query request. 
- Depending on the status and result of the response - this stage may also complete the registration process. - """ - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - if isinstance(op, pipeline_ops_provisioning.RegisterOperation): - initial_register_op = op - self_weakref = weakref.ref(self) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def register_timeout(): - this = self_weakref() - logger.info( - "{stage_name}({op_name}): returning timeout error".format( - stage_name=this.name, op_name=op.name - ) - ) - initial_register_op.complete( - error=( - exceptions.ServiceError( - "Operation timed out before provisioning service could respond for {op_type} operation".format( - op_type=constant.REGISTER - ) - ) - ) - ) - - logger.debug("{}({}): Creating provisioning timeout timer".format(self.name, op.name)) - initial_register_op.provisioning_timeout_timer = Timer( - constant.DEFAULT_TIMEOUT_INTERVAL, register_timeout - ) - initial_register_op.provisioning_timeout_timer.start() - - def on_registration_response(op, error): - self._clear_timeout_timer(initial_register_op, error) - logger.debug( - "{stage_name}({op_name}): Received response with status code {status_code} for RegisterOperation".format( - stage_name=self.name, op_name=op.name, status_code=op.status_code - ) - ) - if error: - logger.info( - "{stage_name}({op_name}): Received error for {prov_op_name} operation".format( - stage_name=self.name, op_name=op.name, prov_op_name=op.request_type - ) - ) - initial_register_op.complete(error=error) - - else: - - if 300 <= op.status_code < 429: - self._process_service_error_status_code(initial_register_op, op) - - elif op.status_code >= 429: - self._process_retry_status_code(error, initial_register_op, op) - - else: - decoded_response = self._decode_response(op) - operation_id = decoded_response.get("operationId", None) - registration_status = decoded_response.get("status", None) - - if registration_status == "assigning": - self_weakref = 
weakref.ref(self) - - def copy_result_to_original_op(op, error): - logger.debug( - "Copying registration result from Query Status Op to Registration Op" - ) - initial_register_op.registration_result = op.registration_result - initial_register_op.error = error - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def do_query_after_interval(): - this = self_weakref() - initial_register_op.polling_timer.cancel() - initial_register_op.polling_timer = None - - logger.info( - "{stage_name}({op_name}): polling".format( - stage_name=this.name, op_name=op.name - ) - ) - - query_worker_op = initial_register_op.spawn_worker_op( - worker_op_type=pipeline_ops_provisioning.PollStatusOperation, - request_payload=" ", - operation_id=operation_id, - callback=copy_result_to_original_op, - ) - - self.send_op_down(query_worker_op) - - logger.debug( - "{stage_name}({op_name}): Op will transition into polling after interval {interval}. Setting timer.".format( - stage_name=self.name, - op_name=op.name, - interval=constant.DEFAULT_POLLING_INTERVAL, - ) - ) - - logger.debug( - "{}({}): Creating polling timer".format(self.name, op.name) - ) - initial_register_op.polling_timer = Timer( - constant.DEFAULT_POLLING_INTERVAL, do_query_after_interval - ) - initial_register_op.polling_timer.start() - - elif registration_status == "failed" or registration_status == "assigned": - self._process_failed_and_assigned_registration_status( - error=error, - operation_id=operation_id, - decoded_response=decoded_response, - registration_status=registration_status, - original_provisioning_op=initial_register_op, - request_response_op=op, - ) - - else: - self._process_unknown_registration_status( - registration_status=registration_status, - original_provisioning_op=initial_register_op, - request_response_op=op, - ) - - registration_payload = DeviceRegistrationPayload( - registration_id=initial_register_op.registration_id, - custom_payload=initial_register_op.request_payload, - ) - self.send_op_down( - 
pipeline_ops_base.RequestAndResponseOperation( - request_type=constant.REGISTER, - method="PUT", - resource_location="/", - request_body=registration_payload.get_json_string(), - callback=on_registration_response, - ) - ) - - else: - super()._run_op(op) - - -class DeviceRegistrationPayload(object): - """ - The class representing the payload that needs to be sent to the service. - """ - - def __init__(self, registration_id, custom_payload=None): - # This is not a convention to name variables in python but the - # DPS service spec needs the name to be exact for it to work - self.registrationId = registration_id - self.payload = custom_payload - - def get_json_string(self): - return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True) diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning_mqtt.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning_mqtt.py deleted file mode 100644 index 75cd319e3..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning_mqtt.py +++ /dev/null @@ -1,153 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -import urllib -from azure.iot.device.common.pipeline import ( - pipeline_ops_base, - pipeline_ops_mqtt, - pipeline_events_mqtt, - pipeline_thread, - pipeline_events_base, - pipeline_exceptions, -) -from azure.iot.device.common.pipeline.pipeline_stages_base import PipelineStage -from azure.iot.device.provisioning.pipeline import mqtt_topic_provisioning -from azure.iot.device import constant as pkg_constant -from . 
import constant as pipeline_constant -from azure.iot.device import user_agent - -logger = logging.getLogger(__name__) - - -class ProvisioningMQTTTranslationStage(PipelineStage): - """ - PipelineStage which converts other Provisioning pipeline operations into MQTT operations. This stage also - converts MQTT pipeline events into Provisioning pipeline events. - """ - - def __init__(self): - super().__init__() - self.action_to_topic = {} - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - - if isinstance(op, pipeline_ops_base.InitializePipelineOperation): - - client_id = self.nucleus.pipeline_configuration.registration_id - query_param_seq = [ - ("api-version", pkg_constant.PROVISIONING_API_VERSION), - ("ClientVersion", user_agent.get_provisioning_user_agent()), - ] - username = "{id_scope}/registrations/{registration_id}/{query_params}".format( - id_scope=self.nucleus.pipeline_configuration.id_scope, - registration_id=self.nucleus.pipeline_configuration.registration_id, - query_params=urllib.parse.urlencode(query_param_seq, quote_via=urllib.parse.quote), - ) - - # Dynamically attach the derived MQTT values to the InitializePipelineOperation - # to be used later down the pipeline - op.username = username - op.client_id = client_id - - self.send_op_down(op) - - elif isinstance(op, pipeline_ops_base.RequestOperation): - if op.request_type == pipeline_constant.REGISTER: - topic = mqtt_topic_provisioning.get_register_topic_for_publish( - request_id=op.request_id - ) - worker_op = op.spawn_worker_op( - worker_op_type=pipeline_ops_mqtt.MQTTPublishOperation, - topic=topic, - payload=op.request_body, - ) - self.send_op_down(worker_op) - elif op.request_type == pipeline_constant.QUERY: - topic = mqtt_topic_provisioning.get_query_topic_for_publish( - request_id=op.request_id, operation_id=op.query_params["operation_id"] - ) - worker_op = op.spawn_worker_op( - worker_op_type=pipeline_ops_mqtt.MQTTPublishOperation, - topic=topic, - payload=op.request_body, - ) 
- self.send_op_down(worker_op) - else: - raise pipeline_exceptions.OperationError( - "RequestOperation request_type {} not supported".format(op.request_type) - ) - - elif isinstance(op, pipeline_ops_base.EnableFeatureOperation): - # The only supported feature is REGISTER - if not op.feature_name == pipeline_constant.REGISTER: - raise pipeline_exceptions.OperationError( - "Trying to enable/disable invalid feature - {}".format(op.feature_name) - ) - # Enabling for register gets translated into an MQTT subscribe operation - topic = mqtt_topic_provisioning.get_register_topic_for_subscribe() - worker_op = op.spawn_worker_op( - worker_op_type=pipeline_ops_mqtt.MQTTSubscribeOperation, topic=topic - ) - self.send_op_down(worker_op) - - elif isinstance(op, pipeline_ops_base.DisableFeatureOperation): - # The only supported feature is REGISTER - if not op.feature_name == pipeline_constant.REGISTER: - raise pipeline_exceptions.OperationError( - "Trying to enable/disable invalid feature - {}".format(op.feature_name) - ) - # Disabling a register response gets turned into an MQTT unsubscribe operation - topic = mqtt_topic_provisioning.get_register_topic_for_subscribe() - worker_op = op.spawn_worker_op( - worker_op_type=pipeline_ops_mqtt.MQTTUnsubscribeOperation, topic=topic - ) - self.send_op_down(worker_op) - - else: - # All other operations get passed down - super()._run_op(op) - - @pipeline_thread.runs_on_pipeline_thread - def _handle_pipeline_event(self, event): - """ - Pipeline Event handler function to convert incoming MQTT messages into the appropriate DPS - events, based on the topic of the message - """ - if isinstance(event, pipeline_events_mqtt.IncomingMQTTMessageEvent): - topic = event.topic - - if mqtt_topic_provisioning.is_dps_response_topic(topic): - logger.debug( - "Received payload:{payload} on topic:{topic}".format( - payload=event.payload, topic=topic - ) - ) - key_values = mqtt_topic_provisioning.extract_properties_from_dps_response_topic( - topic - ) - 
retry_after = key_values.get("retry-after", None) - status_code = mqtt_topic_provisioning.extract_status_code_from_dps_response_topic( - topic - ) - request_id = key_values["rid"] - - self.send_event_up( - pipeline_events_base.ResponseEvent( - request_id=request_id, - status_code=int(status_code, 10), - response_body=event.payload, - retry_after=retry_after, - ) - ) - else: - logger.debug("Unknown topic: {} passing up to next handler".format(topic)) - self.send_event_up(event) - - else: - # all other messages get passed up - self.send_event_up(event) diff --git a/azure-iot-device/azure/iot/device/provisioning/provisioning_device_client.py b/azure-iot-device/azure/iot/device/provisioning/provisioning_device_client.py deleted file mode 100644 index 479f9e9ec..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/provisioning_device_client.py +++ /dev/null @@ -1,111 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -This module contains user-facing synchronous Provisioning Device Client for Azure Provisioning -Device SDK. This client uses Symmetric Key and X509 authentication to register devices with an -IoT Hub via the Device Provisioning Service. 
-""" -import logging -from azure.iot.device.common.evented_callback import EventedCallback -from .abstract_provisioning_device_client import AbstractProvisioningDeviceClient -from .abstract_provisioning_device_client import log_on_register_complete -from azure.iot.device.provisioning.pipeline import constant as dps_constant -from .pipeline import exceptions as pipeline_exceptions -from azure.iot.device import exceptions - - -logger = logging.getLogger(__name__) - - -def handle_result(callback): - try: - return callback.wait_for_completion() - except pipeline_exceptions.ConnectionDroppedError as e: - raise exceptions.ConnectionDroppedError("Lost connection to the provisioning server") from e - except pipeline_exceptions.ConnectionFailedError as e: - raise exceptions.ConnectionFailedError( - "Could not connect to the provisioning server" - ) from e - except pipeline_exceptions.UnauthorizedError as e: - raise exceptions.CredentialError("Credentials invalid, could not connect") from e - except pipeline_exceptions.ProtocolClientError as e: - raise exceptions.ClientError("Error in the provisioning client") from e - except pipeline_exceptions.OperationTimeout as e: - raise exceptions.OperationTimeout("Could not complete operation before timeout") from e - except pipeline_exceptions.PipelineNotRunning as e: - raise exceptions.ClientError("Client has already been shut down") from e - except Exception as e: - raise exceptions.ClientError("Unexpected failure") from e - - -class ProvisioningDeviceClient(AbstractProvisioningDeviceClient): - """ - Client which can be used to run the registration of a device with provisioning service - using Symmetric Key or X509 authentication. - """ - - def register(self): - """ - Register the device with the provisioning service - - This is a synchronous call, meaning that this function will not return until the - registration process has completed successfully or the attempt has resulted in a failure. 
- Before returning, the client will also disconnect from the provisioning service. - If a registration attempt is made while a previous registration is in progress it may - throw an error. - - Once the device is successfully registered, the client will no longer be operable. - - :returns: RegistrationResult indicating the result of the registration. - :rtype: :class:`azure.iot.device.RegistrationResult` - - :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid - and a connection cannot be established. - :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if establishing a - connection results in failure. - :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost - during execution. - :raises: :class:`azure.iot.device.exceptions.OperationTimeout` if the connection times out. - :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure - during execution. - """ - logger.info("Registering with Provisioning Service...") - - if not self._pipeline.responses_enabled[dps_constant.REGISTER]: - self._enable_responses() - - # Register - register_complete = EventedCallback(return_arg_name="result") - self._pipeline.register(payload=self._provisioning_payload, callback=register_complete) - result = handle_result(register_complete) - - log_on_register_complete(result) - - # Implicitly shut down the pipeline upon successful completion - if result is not None and result.status == "assigned": - logger.debug("Beginning pipeline shutdown operation") - shutdown_complete = EventedCallback() - self._pipeline.shutdown(callback=shutdown_complete) - handle_result(shutdown_complete) - logger.debug("Completed pipeline shutdown operation") - - return result - - def _enable_responses(self): - """Enable to receive responses from Device Provisioning Service. - - This is a synchronous call, meaning that this function will not return until the feature - has been enabled. 
- - """ - logger.info("Enabling reception of response from Device Provisioning Service...") - - subscription_complete = EventedCallback() - self._pipeline.enable_responses(callback=subscription_complete) - - handle_result(subscription_complete) - - logger.info("Successfully subscribed to Device Provisioning Service to receive responses") diff --git a/azure-iot-device/azure/iot/device/provisioning_mqtt_client.py b/azure-iot-device/azure/iot/device/provisioning_mqtt_client.py new file mode 100644 index 000000000..0e38dfd74 --- /dev/null +++ b/azure-iot-device/azure/iot/device/provisioning_mqtt_client.py @@ -0,0 +1,537 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import asyncio +import json +import logging +import urllib.parse +import uuid +from typing import Optional, TypeVar +from .custom_typing import ( + RegistrationResult, + RegistrationState, + DeviceRegistrationRequest, + JSONSerializable, +) +from . import config, constant, user_agent +from . import exceptions as exc +from . import request_response as rr +from . import mqtt_client as mqtt +from . 
import mqtt_topic_provisioning as mqtt_topic + +# TODO: update docstrings with correct class paths once repo structured better + +logger = logging.getLogger(__name__) + +DEFAULT_POLLING_INTERVAL: int = 2 +DEFAULT_RECONNECT_INTERVAL: int = 10 +DEFAULT_TIMEOUT_INTERVAL: int = 30 + +_T = TypeVar("_T") + + +class ProvisioningMQTTClient: + def __init__( + self, + client_config: config.ProvisioningClientConfig, + ) -> None: + """Instantiate the client + + :param client_config: The config object for the client + :type client_config: :class:`ProvisioningClientConfig` + """ + # Identity + self._registration_id = client_config.registration_id + self._username = _format_username( + id_scope=client_config.id_scope, + registration_id=self._registration_id, + ) + + # SAS (Optional) + self._sastoken_provider = client_config.sastoken_provider + + # MQTT Configuration + self._mqtt_client = _create_mqtt_client(self._registration_id, client_config) + + # Add filters for receive topics delivering data used internally + register_response_topic = mqtt_topic.get_response_topic_for_subscribe() + self._mqtt_client.add_incoming_message_filter(register_response_topic) + # NOTE: credentials are set upon `.start()` + + # Internal request/response infrastructure + self._request_ledger = rr.RequestLedger() + self._register_responses_enabled = False + + # Background Tasks (Will be set upon `.start()`) + self._process_dps_responses_task: Optional[asyncio.Task[None]] = None + + async def _enable_dps_responses(self) -> None: + """Enable receiving of registration or polling responses from device provisioning service""" + logger.debug("Enabling receive of responses from device provisioning service...") + topic = mqtt_topic.get_response_topic_for_subscribe() + await self._mqtt_client.subscribe(topic) + self._register_responses_enabled = True + logger.debug("Device provisioning service responses receive enabled") + + async def _process_dps_responses(self) -> None: + """Run indefinitely, matching 
responses from DPS with request ID""" + logger.debug("Starting the '_process_dps_responses' background task") + dps_response_topic = mqtt_topic.get_response_topic_for_subscribe() + dps_responses = self._mqtt_client.get_incoming_message_generator(dps_response_topic) + + async for mqtt_message in dps_responses: + try: + extracted_properties = mqtt_topic.extract_properties_from_response_topic( + mqtt_message.topic + ) + request_id = extracted_properties["$rid"] + status_code = int( + mqtt_topic.extract_status_code_from_response_topic(mqtt_message.topic) + ) + # NOTE: We don't know what the content of the body is until we match the rid, so don't + # do more than just decode it here - leave interpreting the string to the coroutine + # waiting for the response. + response_body = mqtt_message.payload.decode("utf-8") + logger.debug("Device provisioning response received (rid: {})".format(request_id)) + logger.debug("Response body is {}".format(response_body)) + response = rr.Response( + request_id=request_id, + status=status_code, + body=response_body, + properties=extracted_properties, + ) + except Exception as e: + logger.error( + "Unexpected error ({}) while translating device provisioning response. Dropping.".format( + e + ) + ) + # NOTE: In this situation the operation waiting for the response that we failed to + # receive will hang. This isn't the end of the world, since it can be cancelled, + # but if we really wanted to smooth this out, we could cancel the pending operation + # based on the request id (assuming getting the request id is not what failed). 
+ # But for now, that's probably overkill, especially since this path ideally should + # never happen, because we would like to assume IoTHub isn't sending malformed data + continue + try: + await self._request_ledger.match_response(response) + except asyncio.CancelledError: + # NOTE: In Python 3.7 this isn't a BaseException, so we must catch and re-raise + raise + except KeyError: + # NOTE: This should only happen in edge cases involving cancellation of + # in-flight operations + logger.warning( + "Device provisioning response (rid: {}) does not match any request".format( + request_id + ) + ) + except Exception as e: + logger.error( + "Unexpected error ({}) while matching Device provisioning response (rid: {}). Dropping response".format( + e, request_id + ) + ) + + async def start(self) -> None: + """Start up the client. + + - Must be invoked before any other methods. + - If already started, will not (meaningfully) do anything. + """ + # Set credentials + if self._sastoken_provider: + logger.debug("Using SASToken as password") + password = str(self._sastoken_provider.get_current_sastoken()) + else: + logger.debug("No password used") + password = None + self._mqtt_client.set_credentials(self._username, password) + # Start background tasks + if not self._process_dps_responses_task: + self._process_dps_responses_task = asyncio.create_task(self._process_dps_responses()) + + async def stop(self) -> None: + """Stop the client. + + - Must be invoked when done with the client for graceful exit. + - If already stopped, will not do anything. + - Cannot be cancelled - if you try, the client will still fully shut down as much as + possible, although CancelledError will still be raised. 
+ """ + cancelled_tasks = [] + logger.debug("Stopping ProvisioningMQTTClient...") + + if self._process_dps_responses_task: + logger.debug("Cancelling '_process_dps_responses' background task") + self._process_dps_responses_task.cancel() + cancelled_tasks.append(self._process_dps_responses_task) + self._process_dps_responses_task = None + + results = await asyncio.gather( + *cancelled_tasks, asyncio.shield(self.disconnect()), return_exceptions=True + ) + for result in results: + # NOTE: Need to specifically exclude asyncio.CancelledError because it is not a + # BaseException in Python 3.7 + if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError): + raise result + + async def connect(self) -> None: + """Connect to Device Provisioning Service + + :raises: MQTTConnectionFailedError if there is a failure connecting + """ + # Connect + logger.debug("Connecting to Device Provisioning Service...") + await self._mqtt_client.connect() + logger.debug("Connect succeeded") + + async def disconnect(self) -> None: + """Disconnect from Device Provisioning Service""" + logger.debug("Disconnecting from Device Provisioning Service...") + await self._mqtt_client.disconnect() + logger.debug("Disconnect succeeded") + + async def wait_for_disconnect(self) -> Optional[exc.MQTTConnectionDroppedError]: + """Block until disconnection and return the cause, if any + + :returns: An MQTTConnectionDroppedError if the connection was dropped, or None if the + connection was intentionally ended + :rtype: MQTTConnectionDroppedError or None + """ + async with self._mqtt_client.disconnected_cond: + await self._mqtt_client.disconnected_cond.wait_for(lambda: not self.connected) + return self._mqtt_client.previous_disconnection_cause() + + async def send_register(self, payload: JSONSerializable = None) -> RegistrationResult: + if not self._register_responses_enabled: + await self._enable_dps_responses() + register_request_id = str(uuid.uuid4()) + register_topic = 
mqtt_topic.get_register_topic_for_publish(request_id=register_request_id) + device_registration_request: DeviceRegistrationRequest = { + "registrationId": self._registration_id, + "payload": payload, + } + publish_payload = json.dumps(device_registration_request) + interval = 0 # Initially set to no sleep + register_response = None + + while True: + await asyncio.sleep(interval) + # Create request with existing request id + # It is either a new request or a re-triable request + request = await self._request_ledger.create_request(register_request_id) + try: + try: + # Send request to DPS + logger.debug( + "Sending register request to Device Provisioning Service... (rid: {})".format( + request.request_id + ) + ) + await self._mqtt_client.publish(register_topic, publish_payload) + except asyncio.CancelledError: + logger.warning( + "Attempt to send register request to Device Provisioning Service was cancelled while in flight." + "It may or may not have been received (rid: {})".format(request.request_id) + ) + raise + except Exception: + logger.error( + "Sending register request to Device Provisioning Service failed (rid: {})".format( + request.request_id + ) + ) + raise + + # Wait for a response from DPS + try: + logger.debug( + "Waiting to receive response for register request from Device Provisioning Service...(rid: {})".format( + request.request_id + ) + ) + # Include a timeout for receipt of response + register_response = await asyncio.wait_for( + request.get_response(), DEFAULT_TIMEOUT_INTERVAL + ) + except asyncio.TimeoutError as te: + logger.debug( + "Attempt to send register request to Device Provisioning Service " + "took more time than allowable limit while waiting for a response. 
If the response arrives, " + "it will be discarded (rid: {})".format(request.request_id) + ) + raise exc.ProvisioningServiceError( + "Device Provisioning Service timed out while waiting for response to the " + "register request...(rid: {}).".format(request.request_id) + ) from te + except asyncio.CancelledError: + logger.debug( + "Attempt to send register request to Device Provisioning Service " + "was cancelled while waiting for a response. If the response arrives, " + "it will be discarded (rid: {})".format(request.request_id) + ) + raise + finally: + # If an exception caused exit before a pending request could be matched with a response + # then manually delete to prevent leaks. + if request.request_id in self._request_ledger: + await self._request_ledger.delete_request(request.request_id) + if register_response: + if 300 <= register_response.status < 429: + raise exc.ProvisioningServiceError( + "Device Provisioning Service responded to the register request with a failed status - {}. 
The detailed error is {}.".format( + register_response.status, register_response.body + ) + ) + elif register_response.status >= 429: + # Process same request for retry again + if register_response.properties is not None: + retry_after = int(register_response.properties.get("retry-after", "0")) + logger.debug( + "Retrying register request after {} secs to Device Provisioning Service...(rid: {})".format( + retry_after, request.request_id + ) + ) + interval = retry_after + else: # happens when response.status 200-300 + logger.debug( + "Received response for register request from Device Provisioning Service " + "(rid: {})".format(request.request_id) + ) + decoded_dps_response = json.loads(register_response.body) + operation_id = decoded_dps_response.get("operationId", None) + registration_status = decoded_dps_response.get("status", None) + if registration_status == "assigning": + # Transition into polling + logger.debug( + "Transitioning to polling request to Device Provisioning Service..." 
+ ) + return await self.send_polling(operation_id) + elif ( + registration_status == "assigned" or registration_status == "failed" + ): # breaking from while + decoded_dps_state = decoded_dps_response.get("registrationState", None) + registration_state: RegistrationState = { + "deviceId": decoded_dps_state.get("deviceId", None), + "assignedHub": decoded_dps_state.get("assignedHub", None), + "subStatus": decoded_dps_state.get("subStatus", None), + "createdDateTimeUtc": decoded_dps_state.get("createdDateTimeUtc", None), + "lastUpdatedDateTimeUtc": decoded_dps_state.get( + "lastUpdatedDateTimeUtc", None + ), + "etag": decoded_dps_state.get("etag", None), + "payload": decoded_dps_state.get("payload", None), + } + registration_result: RegistrationResult = { + "operationId": operation_id, + "status": registration_status, + "registrationState": registration_state, + } + return registration_result + else: + raise exc.ProvisioningServiceError( + "Device Provisioning Service responded to the register request with an invalid " + "registration status {} failed status - {}. The entire error response is {}".format( + registration_status, + register_response.status, + json.loads(register_response.body), + ) + ) + + async def send_polling(self, operation_id: str) -> RegistrationResult: + polling_request_id = str(uuid.uuid4()) + query_topic = mqtt_topic.get_status_query_topic_for_publish( + request_id=polling_request_id, operation_id=operation_id + ) + interval = DEFAULT_POLLING_INTERVAL + query_response = None + while True: + await asyncio.sleep(interval) + # Create request with existing request id + # It is either a new request or a re-triable request + request = await self._request_ledger.create_request(polling_request_id) + try: + # Send the request to DPS, this can be a register or a query request + try: + logger.debug( + "Sending polling request to Device Provisioning Service... 
(rid: {})".format( + request.request_id + ) + ) + await self._mqtt_client.publish(query_topic, " ") + except asyncio.CancelledError: + logger.warning( + "Attempt to send polling request to Device Provisioning Service was cancelled while in flight. " + "It may or may not have been received (rid: {})".format(request.request_id) + ) + raise + except Exception: + logger.error( + "Sending polling request to Device Provisioning Service failed (rid: {})".format( + request.request_id + ) + ) + raise + + # Wait for a response from IoTHub + try: + logger.debug( + "Waiting to receive a response for polling request from Device Provisioning Service... (rid: {})".format( + request.request_id + ) + ) + # response = await request.get_response() + query_response = await asyncio.wait_for( + request.get_response(), DEFAULT_TIMEOUT_INTERVAL + ) + except asyncio.TimeoutError as te: + logger.debug( + "Attempt to send polling request to Device Provisioning Service " + "took more time than allowable limit while waiting for a response. If the response arrives, " + "it will be discarded (rid: {})".format(request.request_id) + ) + raise exc.ProvisioningServiceError( + "Device Provisioning Service timed out while waiting for response to the " + "polling request with (rid: {})".format(request.request_id) + ) from te + except asyncio.CancelledError: + logger.debug( + "Attempt to send polling request to Device Provisioning Service " + "was cancelled while waiting for a response. If the response arrives, " + "it will be discarded (rid: {})".format(request.request_id) + ) + raise + finally: + # If an exception caused exit before a pending request could be matched with a response + # then manually delete to prevent leaks. 
+ if request.request_id in self._request_ledger: + await self._request_ledger.delete_request(request.request_id) + if query_response: + if 300 <= query_response.status < 429: + # breaking from while + raise exc.ProvisioningServiceError( + "Device Provisioning Service responded to the polling request with a failed status - {}. The detailed error is {}. ".format( + query_response.status, query_response.body + ) + ) + elif query_response.status >= 429: + # Process same request for retry again + if query_response.properties is not None: + retry_after = int(query_response.properties.get("retry-after", "0")) + logger.debug( + "Retrying polling request after {} secs to Device Provisioning Service...(rid: {})".format( + retry_after, request.request_id + ) + ) + interval = retry_after + else: # happens when response.status < 300 + logger.debug( + "Received response for polling request from Device Provisioning Service " + "(rid: {})".format(request.request_id) + ) + decoded_dps_response = json.loads(query_response.body) + operation_id = decoded_dps_response.get("operationId", None) + registration_status = decoded_dps_response.get("status", None) + if registration_status == "assigning": + if query_response.properties is not None: + interval = int( + query_response.properties.get( + "retry-after", DEFAULT_POLLING_INTERVAL + ) + ) + logger.debug( + "Retrying polling request after {} secs to Device Provisioning Service...(rid: {})".format( + interval, request.request_id + ) + ) + elif ( + registration_status == "assigned" or registration_status == "failed" + ): # breaking from while + decoded_dps_state = decoded_dps_response.get("registrationState", None) + registration_state: RegistrationState = { + "deviceId": decoded_dps_state.get("deviceId", None), + "assignedHub": decoded_dps_state.get("assignedHub", None), + "subStatus": decoded_dps_state.get("subStatus", None), + "createdDateTimeUtc": decoded_dps_state.get("createdDateTimeUtc", None), + "lastUpdatedDateTimeUtc": 
decoded_dps_state.get( + "lastUpdatedDateTimeUtc", None + ), + "etag": decoded_dps_state.get("etag", None), + "payload": decoded_dps_state.get("payload", None), + } + registration_result: RegistrationResult = { + "operationId": operation_id, + "status": registration_status, + "registrationState": registration_state, + } + return registration_result + else: + raise exc.ProvisioningServiceError( + "Device Provisioning Service responded to the polling request with an invalid " + "registration status {} failed status - {}. The entire error response is {}".format( + registration_status, + query_response.status, + json.loads(query_response.body), + ) + ) + + @property + def connected(self) -> bool: + """Boolean indicating connection status""" + return self._mqtt_client.is_connected() + + +def _create_mqtt_client( + client_id: str, client_config: config.ProvisioningClientConfig +) -> mqtt.MQTTClient: + logger.debug("Creating MQTTClient") + + logger.debug("Using {} as hostname".format(client_config.hostname)) + logger.debug("Using IoTHub Device Registration Id. Client ID is {}".format(client_id)) + + if client_config.websockets: + logger.debug("Using MQTT over websockets") + transport = "websockets" + port = 443 + websockets_path = "/$iothub/websocket" + else: + logger.debug("Using MQTT over TCP") + transport = "tcp" + port = 8883 + websockets_path = None + + client = mqtt.MQTTClient( + client_id=client_id, + hostname=client_config.hostname, + port=port, + transport=transport, + keep_alive=client_config.keep_alive, + auto_reconnect=client_config.auto_reconnect, + reconnect_interval=DEFAULT_RECONNECT_INTERVAL, + ssl_context=client_config.ssl_context, + websockets_path=websockets_path, + proxy_options=client_config.proxy_options, + ) + + return client + + +def _format_username(id_scope: str, registration_id: str) -> str: + query_param_seq = [] + + # Apply query parameters (i.e. 
key1=value1&key2=value2...&keyN=valueN format) + + query_param_seq.append(("api-version", constant.PROVISIONING_API_VERSION)) + query_param_seq.append(("ClientVersion", user_agent.get_provisioning_user_agent())) + + username = "{id_scope}/registrations/{registration_id}/{query_params}".format( + id_scope=id_scope, + registration_id=registration_id, + query_params=urllib.parse.urlencode(query_param_seq, quote_via=urllib.parse.quote), + ) + return username diff --git a/azure-iot-device/azure/iot/device/provisioning_session.py b/azure-iot-device/azure/iot/device/provisioning_session.py new file mode 100644 index 000000000..589e88317 --- /dev/null +++ b/azure-iot-device/azure/iot/device/provisioning_session.py @@ -0,0 +1,244 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import asyncio +import ssl +from typing import Optional, Type, Awaitable, TypeVar +from types import TracebackType + +from . import exceptions as exc +from . import signing_mechanism as sm +from . import sastoken as st +from . import config, custom_typing, constant +from . import provisioning_mqtt_client as mqtt + +_T = TypeVar("_T") + + +class ProvisioningSession: + def __init__( + self, + *, + provisioning_endpoint: str = constant.PROVISIONING_GLOBAL_ENDPOINT, + id_scope: str, + registration_id: str, + ssl_context: Optional[ssl.SSLContext] = None, + shared_access_key: Optional[str] = None, + sastoken_fn: Optional[custom_typing.FunctionOrCoroutine] = None, + sastoken_ttl: int = 3600, + **kwargs, + ) -> None: + """ + :param str provisioning_endpoint: The provisioning endpoint you wish to provision with. 
+ If not provided, defaults to 'global.azure-devices-provisioning.net' + :param str id_scope: The ID scope used to uniquely identify the specific provisioning + service instance to register devices with. + :param str registration_id: The device registration identity being provisioned. + :param ssl_context: Custom SSL context to be used when establishing a connection. + If not provided, a default one will be used + :type ssl_context: :class:`ssl.SSLContext` + :param str shared_access_key: A key that can be used to generate SAS Tokens + :param sastoken_fn: A function or coroutine function that takes no arguments and returns + a SAS token string when invoked + :param sastoken_ttl: Time-to-live (in seconds) for SAS tokens generated when using + 'shared_access_key' authentication. + If using this auth type, a new Session will need to be created once this time expires. + Default is 3600 seconds (1 hour). + + :keyword int keep_alive: Maximum period in seconds between MQTT communications. If no + communications are exchanged for this period, a ping exchange will occur. + Default is 60 seconds + :keyword proxy_options: Configuration structure for sending traffic through a proxy server + :type: proxy_options: :class:`ProxyOptions` + :keyword bool websockets: Set to 'True' to use WebSockets over MQTT. Default is 'False' + + :raises: ValueError if an invalid combination of parameters are provided + :raises: ValueError if an invalid 'symmetric_key' is provided + :raises: TypeError if an invalid keyword argument is provided + """ + # The following validation is present in the previous SDK. 
+ _validate_registration_id(registration_id) + # Validate parameters + _validate_kwargs(**kwargs) + if shared_access_key and sastoken_fn: + raise ValueError( + "Incompatible authentication - cannot provide both 'shared_access_key' and 'sastoken_fn'" + ) + if not shared_access_key and not sastoken_fn and not ssl_context: + raise ValueError( + "Missing authentication - must provide one of 'shared_access_key', 'sastoken_fn' or 'ssl_context'" + ) + + # Set up SAS auth (if using) + generator: Optional[st.SasTokenGenerator] + # NOTE: Need to keep a reference to the SasTokenProvider so we can stop it during cleanup + self._sastoken_provider: Optional[st.SasTokenProvider] + if shared_access_key: + uri = _format_sas_uri(id_scope=id_scope, registration_id=registration_id) + signing_mechanism = sm.SymmetricKeySigningMechanism(shared_access_key) + generator = st.InternalSasTokenGenerator( + signing_mechanism=signing_mechanism, uri=uri, ttl=sastoken_ttl + ) + self._sastoken_provider = st.SasTokenProvider(generator) + elif sastoken_fn: + generator = st.ExternalSasTokenGenerator(sastoken_fn) + self._sastoken_provider = st.SasTokenProvider(generator) + else: + self._sastoken_provider = None + + # Create a default SSLContext if not provided + if not ssl_context: + ssl_context = _default_ssl_context() + + # Instantiate the MQTTClient + client_config = config.ProvisioningClientConfig( + hostname=provisioning_endpoint, + registration_id=registration_id, + id_scope=id_scope, + sastoken_provider=self._sastoken_provider, + ssl_context=ssl_context, + auto_reconnect=False, # No reconnect for now + **kwargs, + ) + self._mqtt_client = mqtt.ProvisioningMQTTClient(client_config) + + # This task is used to propagate dropped connections through receiver generators + # It will be set upon context manager entry and cleared upon exit + # NOTE: If we wanted to design lower levels of the stack to be specific to our + # Session design pattern, this could happen lower (and it would be simpler), but 
it's + # up here so we can be more implementation-generic down the stack. + self._wait_for_disconnect_task: Optional[ + asyncio.Task[Optional[exc.MQTTConnectionDroppedError]] + ] = None + + async def __aenter__(self) -> "ProvisioningSession": + # First, if using SAS auth, start up the provider + if self._sastoken_provider: + # NOTE: No try/except block is needed here because in the case of failure there is not + # yet anything that we would need to clean up. + await self._sastoken_provider.start() + + # Start/connect + try: + await self._mqtt_client.start() + await self._mqtt_client.connect() + except (Exception, asyncio.CancelledError): + # Stop/cleanup if something goes wrong + await self._stop_all() + raise + self._wait_for_disconnect_task = asyncio.create_task( + self._mqtt_client.wait_for_disconnect() + ) + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + traceback: TracebackType, + ) -> None: + try: + await self._mqtt_client.disconnect() + finally: + # TODO: is it dangerous to cancel / remove this task? 
+ if self._wait_for_disconnect_task: + self._wait_for_disconnect_task.cancel() + self._wait_for_disconnect_task = None + await self._stop_all() + + async def _stop_all(self) -> None: + try: + await self._mqtt_client.stop() + finally: + if self._sastoken_provider: + await self._sastoken_provider.stop() + + async def register( + self, payload: custom_typing.JSONSerializable = None + ) -> custom_typing.RegistrationResult: + """Register the device + + :param payload: The JSON serializable data that constitutes the registration payload + :type payload: dict, list, str, int, float, bool, None + + :returns: RegistrationResult + :rtype: RegistrationResult + + :raises: ProvisioningError if a error response is received from IoTHub + :raises: MQTTError if there is an error sending the request + :raises: CancelledError if enabling responses from IoT Hub is cancelled by network failure + """ + if not self._mqtt_client.connected: + # NOTE: We need to raise an error directly if not connected because at MQTT + # Quality of Service (QoS) level 1, used at the lower levels of this stack, + # a MQTT Publish does not actually fail if not connected - instead, it waits + # for a connection to be established, and publishes the data once connected. + # This is not desirable behavior, so we check the connection state. 
+ raise exc.SessionError("ProvisioningSession not connected") + return await self._add_disconnect_interrupt_to_coroutine( + self._mqtt_client.send_register(payload) + ) + + def _add_disconnect_interrupt_to_coroutine(self, coro: Awaitable[_T]) -> Awaitable[_T]: + """Wrap a coroutine in another coroutine that will either return the result of the original + coroutine, or raise error in the event of disconnect + """ + + async def wrapping_coroutine(): + original_task = asyncio.create_task(coro) + done, _ = await asyncio.wait( + [original_task, self._wait_for_disconnect_task], + return_when=asyncio.FIRST_COMPLETED, + ) + if self._wait_for_disconnect_task in done: + original_task.cancel() + cause = self._wait_for_disconnect_task.result() + if cause is not None: + raise cause + else: + raise asyncio.CancelledError("Cancelled by disconnect") + else: + return await original_task + + return wrapping_coroutine() + + +def _validate_kwargs(exclude=[], **kwargs) -> None: + """Helper function to validate user provided kwargs. + Raises TypeError if an invalid option has been provided""" + valid_kwargs = [ + # "auto_reconnect", + "keep_alive", + "proxy_options", + "websockets", + ] + + for kwarg in kwargs: + if (kwarg not in valid_kwargs) or (kwarg in exclude): + # NOTE: TypeError is the conventional error that is returned when an invalid kwarg is + # supplied. It feels like it should be a ValueError, but it's not. 
+ raise TypeError("Unsupported keyword argument: '{}'".format(kwarg)) + + +def _validate_registration_id(reg_id: str): + if not (reg_id and reg_id.strip()): + raise ValueError("Registration Id can not be none, empty or blank.") + + +def _format_sas_uri(id_scope: str, registration_id: str) -> str: + """Format the SAS URI DPS""" + return "{id_scope}/registrations/{registration_id}".format( + id_scope=id_scope, registration_id=registration_id + ) + + +def _default_ssl_context() -> ssl.SSLContext: + """Return a default SSLContext""" + ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.check_hostname = True + ssl_context.load_default_certs() + return ssl_context diff --git a/azure-iot-device/azure/iot/device/request_response.py b/azure-iot-device/azure/iot/device/request_response.py new file mode 100644 index 000000000..7dd7fa7c0 --- /dev/null +++ b/azure-iot-device/azure/iot/device/request_response.py @@ -0,0 +1,61 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""Infrastructure for use implementing a high-level async request/response paradigm""" +import asyncio +import uuid +from typing import Dict, Optional + + +class Response: + def __init__( + self, request_id: str, status: int, body: str, properties: Optional[Dict[str, str]] = None + ) -> None: + self.request_id = request_id + self.status = status + self.body = body # TODO: naming - "result"? 
+ self.properties = properties + + +class Request: + def __init__(self, request_id: Optional[str] = None) -> None: + if request_id: + self.request_id = request_id + else: + self.request_id = str(uuid.uuid4()) + self.response_future: asyncio.Future[Response] = asyncio.Future() + + async def get_response(self) -> Response: + return await self.response_future + + +class RequestLedger: + def __init__(self) -> None: + self.lock = asyncio.Lock() + self.pending: Dict[str, asyncio.Future[Response]] = {} + + def __len__(self) -> int: + return len(self.pending) + + def __contains__(self, request_id): + return request_id in self.pending + + async def create_request(self, request_id: Optional[str] = None) -> Request: + request = Request(request_id=request_id) + async with self.lock: + if request.request_id not in self.pending: + self.pending[request.request_id] = request.response_future + else: + raise ValueError("Provided request_id is a duplicate") + return request + + async def delete_request(self, request_id) -> None: + async with self.lock: + del self.pending[request_id] + + async def match_response(self, response: Response) -> None: + async with self.lock: + self.pending[response.request_id].set_result(response) + del self.pending[response.request_id] diff --git a/azure-iot-device/azure/iot/device/sastoken.py b/azure-iot-device/azure/iot/device/sastoken.py new file mode 100644 index 000000000..306f1e619 --- /dev/null +++ b/azure-iot-device/azure/iot/device/sastoken.py @@ -0,0 +1,255 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +"""This module contains tools for working with Shared Access Signature (SAS) Tokens""" + +import abc +import asyncio +import logging +import time +import urllib.parse +from typing import Dict, List, Optional, Awaitable, Callable, cast +from .custom_typing import FunctionOrCoroutine +from .signing_mechanism import SigningMechanism + + +logger = logging.getLogger(__name__) + +DEFAULT_TOKEN_UPDATE_MARGIN: int = 120 +REQUIRED_SASTOKEN_FIELDS: List[str] = ["sr", "sig", "se"] +TOKEN_FORMAT: str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}" + + +class SasTokenError(Exception): + """Error in SasToken""" + + pass + + +class SasToken: + def __init__(self, sastoken_str: str) -> None: + """Create a SasToken object from a SAS Token string + :param str sastoken_str: The SAS Token string + + :raises: ValueError if SAS Token string is invalid + """ + self._token_str: str = sastoken_str + self._token_info: Dict[str, str] = _get_sastoken_info_from_string(sastoken_str) + + def __str__(self) -> str: + return self._token_str + + @property + def expiry_time(self) -> float: + # NOTE: Time is typically expressed in float in Python, even though a + # SAS Token expiry time should be a whole number. + return float(self._token_info["se"]) + + @property + def resource_uri(self) -> str: + uri = self._token_info["sr"] + return urllib.parse.unquote(uri) + + @property + def signature(self) -> str: + signature = self._token_info["sig"] + return urllib.parse.unquote(signature) + + +class SasTokenGenerator(abc.ABC): + @abc.abstractmethod + async def generate_sastoken(self): + pass + + +# TODO: SigningMechanismSasTokenGenerator? 
+class InternalSasTokenGenerator(SasTokenGenerator): + def __init__(self, signing_mechanism: SigningMechanism, uri: str, ttl: int = 3600) -> None: + """An object that can generate SasTokens using provided values + + :param str uri: The URI of the resource you are generating a tokens to access + :param signing_mechanism: The signing mechanism that will be used to sign data + :type signing mechanism: :class:`SigningMechanism` + :param int ttl: Time to live for generated tokens, in seconds (default 3600) + """ + self.signing_mechanism = signing_mechanism + self.uri = uri + self.ttl = ttl + + async def generate_sastoken(self) -> SasToken: + """Generate a new SasToken + + :raises: SasTokenError if the token cannot be generated + """ + expiry_time = int(time.time()) + self.ttl + url_encoded_uri = urllib.parse.quote(self.uri, safe="") + message = url_encoded_uri + "\n" + str(expiry_time) + try: + signature = await self.signing_mechanism.sign(message) + except Exception as e: + # Because of variant signing mechanisms, we don't know what error might be raised. + # So we catch all of them. + raise SasTokenError("Unable to generate SasToken") from e + url_encoded_signature = urllib.parse.quote(signature, safe="") + token_str = TOKEN_FORMAT.format( + resource=url_encoded_uri, + signature=url_encoded_signature, + expiry=str(expiry_time), + ) + return SasToken(token_str) + + +class ExternalSasTokenGenerator(SasTokenGenerator): + # TODO: need more specificity in generator_fn + def __init__(self, generator_fn: FunctionOrCoroutine): + """An object that can generate SasTokens by invoking a provided callable. + This callable can be a function or a coroutine function. 
+ + :param generator_fn: A callable that takes no arguments and returns a SAS Token string + :type generator_fn: Function or Coroutine Function which returns a string + """ + self.generator_fn = generator_fn + + async def generate_sastoken(self) -> SasToken: + """Generate a new SasToken + + :raises: SasTokenError if the token cannot be generated + """ + try: + # NOTE: the typechecker has some problems here, so we help it with a cast. + if asyncio.iscoroutinefunction(self.generator_fn): + generator_fn = cast(Callable[[], Awaitable[str]], self.generator_fn) + token_str = await generator_fn() + else: + generator_coro_fn = cast(Callable[[], str], self.generator_fn) + token_str = generator_coro_fn() + return SasToken(token_str) + except Exception as e: + raise SasTokenError("Unable to generate SasToken") from e + + +class SasTokenProvider: + def __init__(self, generator: SasTokenGenerator) -> None: + """Object responsible for providing a valid SasToken. + + :param generator: A SasTokenGenerator to generate SasTokens with + :type generator: SasTokenGenerator + """ + # NOTE: There is no good way to invoke a coroutine from within the __init__, and since + # the the generator's .sign() method is a coroutine, that means we can't generate an + # initial token from it here. Thus, we have to take the initial token as a separate + # argument. + # However, this is inconvenient, and also prevents us from fast-failing if there's a + # problem with the generator_fn, so a factory coroutine method has been implemented. 
+ self._event_loop = asyncio.get_running_loop() + self._generator = generator + self._token_update_margin = DEFAULT_TOKEN_UPDATE_MARGIN + self._new_sastoken_available = asyncio.Condition() + + # Will be set upon `.start()` + self._current_token: Optional[SasToken] = None + self._keep_token_fresh_bg_task: Optional[asyncio.Task[None]] = None + + async def _keep_token_fresh(self): + """Runs indefinitely and will generate a SasToken when the current one gets close to + expiration (based on the update margin) + """ + generate_time = self._current_token.expiry_time - self._token_update_margin + while True: + await _wait_until(generate_time) + try: + logger.debug("Updating SAS Token...") + self._current_token = await self._generator.generate_sastoken() + logger.debug("SAS Token update succeeded") + # TODO: validate that this is a valid token? + generate_time = self._current_token.expiry_time - self._token_update_margin + async with self._new_sastoken_available: + self._new_sastoken_available.notify_all() + except Exception: + logger.error("SAS Token renewal failed. Trying again in 10 seconds") + generate_time = time.time() + 10 + + async def start(self): + """Begin running the SasTokenProvider, ensuring that the current token is always valid""" + if not self._keep_token_fresh_bg_task: + logger.debug("Starting SasTokenProvider") + initial_token = await self._generator.generate_sastoken() + if initial_token.expiry_time < time.time(): + raise SasTokenError("Newly generated SAS Token has already expired") + self._current_token = initial_token + async with self._new_sastoken_available: + self._new_sastoken_available.notify_all() + self._keep_token_fresh_bg_task = asyncio.create_task(self._keep_token_fresh()) + else: + logger.debug("SasTokenProvider already running, no need to start") + + async def stop(self): + """Stop running the SasTokenProvider, clearing the current token. + Does nothing if already stopped. 
+ """ + # Cancel and wait for cancellation to complete + if self._keep_token_fresh_bg_task: + logger.debug("Stopping SasTokenProvider") + self._keep_token_fresh_bg_task.cancel() + await asyncio.gather(self._keep_token_fresh_bg_task, return_exceptions=True) + self._keep_token_fresh_bg_task = None + # NOTE: There is an argument to be made that this value shouldn't be cleared, + # as the SasTokenProvider may be started again while it remains valid, but for + # now, we clear it for simplicity. + self._current_token = None + else: + logger.debug("SasTokenProvider was not running, no need to stop") + + def get_current_sastoken(self) -> SasToken: + """Return the current SasToken""" + if self._current_token: + return self._current_token + else: + raise RuntimeError("SasTokenProvider is not running") + + async def wait_for_new_sastoken(self) -> SasToken: + """Waits for a new SasToken to become available, and return it""" + async with self._new_sastoken_available: + await self._new_sastoken_available.wait() + return self.get_current_sastoken() + + +def _get_sastoken_info_from_string(sastoken_string: str) -> Dict[str, str]: + """Given a SAS Token string, return a dictionary of it's keys and values""" + pieces = sastoken_string.split("SharedAccessSignature ") + if len(pieces) != 2: + raise ValueError("Invalid SAS Token string: Not a SAS Token ") + + # Get sastoken info as dictionary + try: + # TODO: fix this typehint later, it needs some kind of cast + sastoken_info = dict(map(str.strip, sub.split("=", 1)) for sub in pieces[1].split("&")) # type: ignore + except Exception as e: + raise ValueError("Invalid SAS Token string: Incorrectly formatted") from e + + # Validate that all required fields are present + if not all(key in sastoken_info for key in REQUIRED_SASTOKEN_FIELDS): + raise ValueError("Invalid SAS Token string: Not all required fields present") + + # Warn if extraneous fields are present + if not all(key in REQUIRED_SASTOKEN_FIELDS for key in sastoken_info): + 
logger.warning("Unexpected fields present in SAS Token") + + return sastoken_info + + +# NOTE: Arguably, this doesn't really belong in this module, give it's lack of a specific +# relationship to SAS Tokens, and the fact that it needs to be unit-tested separately. +# These things suggest it should be more than just a convention-private helper, however +# its hard to justify making a separate module just for this function. +# This would be a candidate for some kind of misc utility module if other similar functions +# pop up over the course of development. Until then, it lives here. +async def _wait_until(when: float) -> None: + """Wait until a specific time has passed (accurate within 1 second). + + :param float when: The time to wait for, in seconds, since epoch + """ + while time.time() < when: + await asyncio.sleep(1) diff --git a/azure-iot-device/azure/iot/device/common/auth/signing_mechanism.py b/azure-iot-device/azure/iot/device/signing_mechanism.py similarity index 51% rename from azure-iot-device/azure/iot/device/common/auth/signing_mechanism.py rename to azure-iot-device/azure/iot/device/signing_mechanism.py index c0831a80a..e1e7210f0 100644 --- a/azure-iot-device/azure/iot/device/common/auth/signing_mechanism.py +++ b/azure-iot-device/azure/iot/device/signing_mechanism.py @@ -3,43 +3,48 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -"""This module defines an abstract SigningMechanism, as well as common child implementations of it -""" import abc +import base64 +import binascii import hmac import hashlib -import base64 +from typing import AnyStr + +# TODO: remove commented signatures class SigningMechanism(abc.ABC): @abc.abstractmethod - def sign(self, data_str): + async def sign(self, data_str: AnyStr) -> str: + # NOTE: This is defined as a coroutine to allow for flexibility of implementation. 
+ # Some implementations may not require a coroutine, but others may, so we err on the side + # of a coroutine for consistent interface. pass class SymmetricKeySigningMechanism(SigningMechanism): - def __init__(self, key): + def __init__(self, key: AnyStr) -> None: """ A mechanism that signs data using a symmetric key :param key: Symmetric Key (base64 encoded) :type key: str or bytes + + :raises: ValueError if provided key is invalid """ - # Convert key to bytes - try: - key = key.encode("utf-8") - except AttributeError: - # If byte string, no need to encode - pass + # Convert key to bytes (if not already) + if isinstance(key, str): + key_bytes = key.encode("utf-8") + else: + key_bytes = key # Derives the signing key - # CT-TODO: is "signing key" the right term? try: - self._signing_key = base64.b64decode(key) - except (base64.binascii.Error): + self._signing_key = base64.b64decode(key_bytes) + except (binascii.Error): raise ValueError("Invalid Symmetric Key") - def sign(self, data_str): + async def sign(self, data_str: AnyStr) -> str: """ Sign a data string with symmetric key and the HMAC-SHA256 algorithm. @@ -48,18 +53,22 @@ def sign(self, data_str): :returns: The signed data :rtype: str + + :raises: ValueError if an invalid data string is provided """ - # Convert data_str to bytes - try: - data_str = data_str.encode("utf-8") - except AttributeError: - # If byte string, no need to encode - pass + # NOTE: This implementation doesn't take advantage of being a coroutine, but this is by + # design. See the definition of the abstract base class above. 
+ + # Convert data_str to bytes (if not already) + if isinstance(data_str, str): + data_bytes = data_str.encode("utf-8") + else: + data_bytes = data_str # Derive signature via HMAC-SHA256 algorithm try: hmac_digest = hmac.HMAC( - key=self._signing_key, msg=data_str, digestmod=hashlib.sha256 + key=self._signing_key, msg=data_bytes, digestmod=hashlib.sha256 ).digest() signed_data = base64.b64encode(hmac_digest) except (TypeError): diff --git a/azure-iot-device/azure/iot/device/user_agent.py b/azure-iot-device/azure/iot/device/user_agent.py index a07615629..5f946474a 100644 --- a/azure-iot-device/azure/iot/device/user_agent.py +++ b/azure-iot-device/azure/iot/device/user_agent.py @@ -6,7 +6,7 @@ """This module is for creating agent strings for all clients""" import platform -from azure.iot.device.constant import VERSION, IOTHUB_IDENTIFIER, PROVISIONING_IDENTIFIER +from .constant import VERSION, IOTHUB_IDENTIFIER, PROVISIONING_IDENTIFIER python_runtime = platform.python_version() os_type = platform.system() @@ -14,7 +14,7 @@ architecture = platform.machine() -def _get_common_user_agent(): +def _get_common_user_agent() -> str: return "({python_runtime};{os_type} {os_release};{architecture})".format( python_runtime=python_runtime, os_type=os_type, @@ -23,7 +23,7 @@ def _get_common_user_agent(): ) -def get_iothub_user_agent(): +def get_iothub_user_agent() -> str: """ Create the user agent for IotHub """ @@ -32,7 +32,7 @@ def get_iothub_user_agent(): ) -def get_provisioning_user_agent(): +def get_provisioning_user_agent() -> str: """ Create the user agent for Provisioning """ diff --git a/credscan_suppression.json b/credscan_suppression.json index d7785785b..c0e755f4b 100644 --- a/credscan_suppression.json +++ b/credscan_suppression.json @@ -10,37 +10,12 @@ "_justification": "Test containing fake passwords and keys" }, { - "file": "\\tests\\unit\\common\\auth\\test_signing_mechanism.py", + "file": "\\tests\\unit\\test_signing_mechanism.py", "_justification": "Test 
containing fake keys" }, { - "file": "\\tests\\unit\\common\\auth\\test_sastoken.py", - "_justification": "Test containing fake signed data" - }, - { - "file": "\\tests\\unit\\common\\test_mqtt_transport.py", - "_justification": "Test containing fake passwords" - }, - { - "file": "\\tests\\unit\\common\\test_http_transport.py", - "_justification": "Test containing fake passwords" - }, - { - "file": "\\tests\\unit\\iothub\\shared_client_tests.py", - "_justification": "Test containing fake signed data" - }, - { - "file": "\\tests\\unit\\iothub\\client_fixtures.py", - "_justification": "Test containing fake keys and fake signed data" - }, - { - "file": "\\tests\\unit\\iothub\\test_sync_clients.py", - "_justification": "Test containing fake signed data" - }, - { - "file": "\\tests\\unit\\iothub\\aio\\test_async_clients.py", + "file": "\\tests\\unit\\test_sastoken.py", "_justification": "Test containing fake signed data" } ] - -} \ No newline at end of file +} diff --git a/dev_utils/dev_utils/custom_mock.py b/dev_utils/dev_utils/custom_mock.py new file mode 100644 index 000000000..c7e65d190 --- /dev/null +++ b/dev_utils/dev_utils/custom_mock.py @@ -0,0 +1,44 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import asyncio +import mock + + +class HangingAsyncMock(mock.AsyncMock): + """Use this mock to hang on a awaitable coroutine. + Useful for testing task cancellation, or blocking an infinite loop. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.side_effect = self._do_hang + self._is_hanging = asyncio.Event() + self._stop_hanging = asyncio.Event() + + async def _do_hang(self, *args, **kwargs): + if not self._is_hanging.is_set(): + self._stop_hanging.clear() + self._is_hanging.set() + await self._stop_hanging.wait() + return self.return_value + + async def wait_for_hang(self): + await self._is_hanging.wait() + + def is_hanging(self): + return self._is_hanging.is_set() + + def stop_hanging(self): + if self._is_hanging.is_set(): + self._stop_hanging.set() + self._is_hanging.clear() + else: + raise RuntimeError("Not hanging") + + def reset_mock(self): + self._is_hanging.clear() + self._stop_hanging.clear() + super().reset_mock() diff --git a/dev_utils/dev_utils/iothub_mqtt_helper.py b/dev_utils/dev_utils/iothub_mqtt_helper.py new file mode 100644 index 000000000..7f662c68f --- /dev/null +++ b/dev_utils/dev_utils/iothub_mqtt_helper.py @@ -0,0 +1,82 @@ +# TODO: REMOVE THIS WHEN NO LONGER TESTING AT IOTHUB-MQTT LEVEL + +from azure.iot.device.config import IoTHubClientConfig +from azure.iot.device import sastoken as st +from azure.iot.device import signing_mechanism as sm +from azure.iot.device import connection_string as cs +import ssl +import logging + +logger = logging.getLogger(__name__) + + +async def create_client_config(cs_str): + connection_string = cs.ConnectionString(cs_str) + hostname = connection_string[cs.HOST_NAME] + device_id = connection_string[cs.DEVICE_ID] + module_id = connection_string.get(cs.MODULE_ID) + + generator = _create_sastoken_generator(connection_string) + sastoken_provider = await st.SasTokenProvider.create_from_generator(generator) + + ssl_context = _create_ssl_context() + + return IoTHubClientConfig( + device_id=device_id, + module_id=module_id, + hostname=hostname, + sastoken_provider=sastoken_provider, + ssl_context=ssl_context, + ) + + +def _create_sastoken_generator(connection_string, 
ttl=3600): + uri = _form_sas_uri( + hostname=connection_string[cs.HOST_NAME], + device_id=connection_string[cs.DEVICE_ID], + module_id=connection_string.get(cs.MODULE_ID), + ) + signing_mechanism = sm.SymmetricKeySigningMechanism(key=connection_string[cs.SHARED_ACCESS_KEY]) + sastoken_generator = st.InternalSasTokenGenerator(signing_mechanism, uri, ttl) + return sastoken_generator + + +def _form_sas_uri(hostname, device_id, module_id=None): + if module_id: + return "{hostname}/devices/{device_id}/modules/{module_id}".format( + hostname=hostname, device_id=device_id, module_id=module_id + ) + else: + return "{hostname}/devices/{device_id}".format(hostname=hostname, device_id=device_id) + + +# TODO: use this logic directly in the client later +def _create_ssl_context(server_verification_cert=None, cipher=None, x509=None) -> ssl.SSLContext: + ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.check_hostname = True + + if server_verification_cert: + logger.debug("Configuring SSLContext with custom server verification cert") + ssl_context.load_verify_locations(cadata=server_verification_cert) + else: + logger.debug("Configuring SSLContext with default certs") + ssl_context.load_default_certs() + + if cipher: + logger.debug("Configuring SSLContext with cipher suites") + ssl_context.set_ciphers(cipher) + else: + logger.debug("Not using cipher suites") + + if x509: + logger.debug("Configuring SSLContext with client-side X509 certificate and key") + ssl_context.load_cert_chain( + x509.certificate_file, + x509.key_file, + x509.pass_phrase, + ) + else: + logger.debug("Not using X509 certificates") + + return ssl_context diff --git a/dev_utils/dev_utils/leak_tracker.py b/dev_utils/dev_utils/leak_tracker.py index 6c38d4710..bc03dc84e 100644 --- a/dev_utils/dev_utils/leak_tracker.py +++ b/dev_utils/dev_utils/leak_tracker.py @@ -12,6 +12,9 @@ logger = logging.getLogger(__name__) 
logger.setLevel(level=logging.INFO) +# When printing leaks, how many referrer levels do we print? +max_referrer_level = 5 + def _run_garbage_collection(): """ @@ -37,13 +40,13 @@ def get_printable_object_name(obj): """ try: if isinstance(obj, dict): - return "Dict ID={}: first-5-keys={}".format(id(obj), list(obj.keys())[:10]) + return "{} Dict: first-5-keys={}".format(id(obj), list(obj.keys())[:10]) else: - return "{}: {}".format(type(obj), str(obj)) + return "{} {}: {}".format(id(obj), type(obj), str(obj)) except TypeError: - return "Foreign object (raised TypeError): {}, ID={}".format(type(obj), id(obj)) + return "{} TypeError on {}".format(id(obj), type(obj)) except ModuleNotFoundError: - return "Foreign object (raised ModuleNotFoundError): {}, ID={}".format(type(obj), id(obj)) + return "{} ModuleNotFoundError on {}".format(id(obj), type(obj)) class TrackedObject(object): @@ -88,8 +91,10 @@ def __ne__(self, obj): def get_object(self): if self.weakref: return self.weakref() - else: + elif self.dict: return self.dict + elif self.object_id: + return [x for x in gc.get_objects() if id(x) == self.object_id][0] class TrackedModule(object): @@ -197,191 +202,68 @@ def dump_leak_report(self, leaked_objects): Dump a report on leaked objects, including a list of what referrers to each leaked object. - In order to display useful information on referrers, we need to do some ID-to-object - mapping. This is necessary because of the way the garbage collector keeps track of - references between objects. - - To explain this, if we have object `a` that refers to object `b`: - ``` - >>> class Object(object): - ... pass - ... - >>> a = Object() - >>> b = Object() - ``` - - And `a` has a reference to `b`: - ``` - >>> a.something = b - ``` - - This means that `a` has a reference to `b`, which means that `b` will not be collected - until _after_ `a` is collected. In other words. `a` is keeping `b` alive. - - You can see this by using `gc.get_referrers(b)` to see who refers to `b`. 
But, If you do - this, it will tell you that `a` does _not_ refer to `b`. Instead, it is `a.__dict__` - that refers to `b`. - - ``` - >>> a in gc.get_referrers(b) - False - >>> a.__dict__ in gc.get_referrers(b) - True - ``` - - This feels counterintuitive because, from your viewpoint, `a` does refer to `b`. - However, from the garage collector's viewpoint, `a` refers to `a.__dict__` and `a.__dict__` - refers to `b`. In effect, `a` does refer to `b`, but it does so indirectly through - `a.__dict__`: - - ``` - >>> a.__dict__ in gc.get_referrers(b) - True - >>> a in gc.get_referrers(a.__dict__) - True - ``` - - If, however, object `a` uses `__slots__` to refer to `b`, then object `a` will refer - to object `b` and `a.___dict__` will not exist.` - - ``` - >>> class ObjectWithSlots(object): - ... __slots__ = ["something"] - ... - >>> a = ObjectWithSlots() - >>> b = Object() - >>> a.something = b - >>> a in gc.get_referrers(b) - True - >>> a.__dict__ in gc.get_referrers(b) - Traceback (most recent call last): - File "", line 1, in - AttributeError: 'ObjectWithSlots' object has no attribute '__dict<' - ``` - - This can be complicated to keep track of. So, to dump useful information, we use - `id_to_name_map` to keep track of the relationship between `a` and `a.__dict__`. - In effect: - - ``` - id_to_name_map[id(a)] = str(a) - id_to_name_map[id(a.__dict__)] = str(a) - ``` - - With this mapping, we can show that `a` refers to `b`, even when it is `a.__dict__` that is - referring to `b`. - - Phew. + This prints leaked objects and the objects that are referring to them (keeping them + alive) underneath, up to `max_referrer_level` levels deep. + For example: + + ``` + A = Object() + A.B = Object() + A.B.C = Object() + ``` + + If C is marked as a leak, This will display + ``` + ID(C) C + ID(B) B + ID(A) A + ``` + Because `B` is keeping `C` alive, and `A` is keeping `B` alive. + """ logger.info("-----------------------------------------------") logger.error("Test failure. 
{} objects have leaked:".format(len(leaked_objects))) - logger.info("(Default text format is ") - - id_to_name_map = {} - - # first, map IDs for leaked objects. We display these slightly differently because it - # makes tracking inter-leak references a little easier. - for leak in leaked_objects: - id_to_name_map[leak.object_id] = leak - - # if the object has a `__dict__` attribute, then map the ID of that dictionary - # back to the object also. - if leak.get_object() and hasattr(leak.get_object(), "__dict__"): - dict_id = id(leak.get_object().__dict__) - id_to_name_map[dict_id] = leak - - # Second, go through all objects and map IDs for those (unless we've done them already). - # In this step, we add mappings for objects and their `__dict__` attributes, but we - # don't add `dict` objects yet. This is because we don't know if any `dict` is a user- - # created dictionary or if it's a `__dict__`. If it's a `__dict__`, we add it here and - # point it to the owning object. If it's just a `dict`, we add it in the last loop - # through - for obj in gc.get_objects(): - object_id = id(obj) - - if not isinstance(obj, dict): - if object_id not in id_to_name_map: - id_to_name_map[object_id] = TrackedObject(obj) - - if hasattr(obj, "__dict__"): - dict_id = id(obj.__dict__) - if dict_id not in id_to_name_map: - id_to_name_map[dict_id] = id_to_name_map[object_id] + logger.info("(Default text format is ") + logger.info("Printing to {} levels deep".format(max_referrer_level)) + + visited = set() + + all_objects = gc.get_objects() + leaked_object_ids = [x.object_id for x in leaked_objects] + + # This is the function that recursively displays leaks and the objets that refer + # to them. 
+ def visit(object, indent): + line = f"{' ' * indent} {get_printable_object_name(object)}" + if indent > max_referrer_level: + # Stop printing at `max_referrer_level` levels deep + print(f"{line} (reached max depth)") + return + if id(object) == id(all_objects): + # all_objects has a reference to all objects. Stop if we reach it. + return + elif indent > 0 and id(object) in leaked_object_ids: + # We've hit an object at the top level, but we're not at the top level. + # this means one of our leaked objects is referring to another of our leaked objects. + # Stop here. + print(f"{line} (top-level leak)") + return + elif id(object) in visited: + # stop if we've previously visited this object + print(f"{line} (previously visited)") + return + elif str(type(object)) in ["", ""]: + # stop at list or list_iterator objects. There are too many of these and + # they don't provide any useful information. + return + else: + print(f"{' ' * indent} {get_printable_object_name(object)}") + visited.add(id(object)) + for referrer in gc.get_referrers(object): + visit(referrer, indent + 1) - # Third, map IDs for all dicts that we haven't done yet. - for obj in gc.get_objects(): - object_id = id(obj) - - if isinstance(obj, dict): - if object_id not in id_to_name_map: - id_to_name_map[object_id] = TrackedObject(obj) - - already_reported = set() - objects_to_report = leaked_objects.copy() - - # keep track of all 3 generations in handy local variables. These are here - # for developers who might be looking at leaks inside of pdb. 
- gen0 = [] - gen1 = [] - gen2 = [] - - for generation_storage, generation_name in [ - (gen0, "generation 0: objects that leaked"), - (gen1, "generation 1: objects that refer to leaked objects"), - (gen2, "generation 2: objects that refer to generation 1"), - ]: - next_set_of_objects_to_report = set() - if len(objects_to_report): - logger.info("-----------------------------------------------") - logger.info(generation_name) - - # Add our objects to our generation-specific list. This helps - # developers looking at bugs inside pdb because they can just look - # at `gen0[0].get_object()` to see the first leaked object, etc. - generation_storage.extend(objects_to_report) - - for obj in objects_to_report: - if obj in already_reported: - logger.info("already reported: {}".format(obj.object_name)) - else: - logger.info("object: {}".format(obj.object_name)) - if not obj.get_object(): - logger.info(" not recursing") - else: - for referrer in gc.get_referrers(obj.get_object()): - if ( - isinstance(referrer, dict) - and referrer.get("dict", None) == obj.get_object() - ): - # This is the dict from a TrackedObject object. Skip it. - pass - else: - object_id = id(referrer) - if object_id in id_to_name_map: - logger.info( - " referred by: {}".format(id_to_name_map[object_id]) - ) - next_set_of_objects_to_report.add(id_to_name_map[object_id]) - else: - logger.info( - " referred by Non-object: {}".format( - get_printable_object_name(referrer) - ) - ) - already_reported.add(obj) - - logger.info( - "Total: {} objects, referred to by {} objects".format( - len(objects_to_report), len(next_set_of_objects_to_report) - ) - ) - objects_to_report = next_set_of_objects_to_report + for object in leaked_objects: + visit(object.get_object(), 0) - logger.info("-----------------------------------------------") - logger.info("Leaked objects are available in local variables: gen0, gen1, and gen2") - logger.info("for the 3 generations of leaks. 
Use the get_object method to retrieve") - logger.info("the actual objects") - logger.info("eg: us gen0[0].get_object() to get the first leaked object") - logger.info("-----------------------------------------------") assert False, "Test failure. {} objects have leaked:".format(len(leaked_objects)) diff --git a/dev_utils/dev_utils/mqtt_helper.py b/dev_utils/dev_utils/mqtt_helper.py new file mode 100644 index 000000000..f4030579e --- /dev/null +++ b/dev_utils/dev_utils/mqtt_helper.py @@ -0,0 +1,102 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""Temporary module used to help set up transport for manual testing""" +# TODO: remove this when not testing at MQTT level +import ssl +import logging +import urllib +from azure.iot.device import connection_string as cs +from azure.iot.device import sastoken as st +from azure.iot.device import signing_mechanism as sm + +logger = logging.getLogger(__name__) + +WS_PATH = "/$iothub/websocket" +IOTHUB_API_VERSION = "2019-10-01" + + +def get_client_id(connection_string): + connection_string = cs.ConnectionString(connection_string) + return connection_string[cs.DEVICE_ID] + + +def get_hostname(connection_string): + connection_string = cs.ConnectionString(connection_string) + return connection_string[cs.HOST_NAME] + + +def get_username(connection_string): + connection_string = cs.ConnectionString(connection_string) + + query_param_seq = [] + query_param_seq.append(("api-version", IOTHUB_API_VERSION)) + # query_param_seq.append( + # ("DeviceClientType", user_agent.get_iothub_user_agent()) + # ) + + username = "{hostname}/{client_id}/?{query_params}".format( + hostname=connection_string[cs.HOST_NAME], + client_id=get_client_id(str(connection_string)), + 
query_params=urllib.parse.urlencode(query_param_seq, quote_via=urllib.parse.quote), + ) + return username + + +def get_password(connection_string, ttl=3600): + connection_string = cs.ConnectionString(connection_string) + uri = _form_sas_uri( + hostname=connection_string[cs.HOST_NAME], + device_id=connection_string[cs.DEVICE_ID], + module_id=connection_string.get(cs.MODULE_ID), + ) + signing_mechanism = sm.SymmetricKeySigningMechanism(key=connection_string[cs.SHARED_ACCESS_KEY]) + sastoken = st.RenewableSasToken(uri, signing_mechanism, ttl=ttl) + return str(sastoken) + + +def create_ssl_context(server_verification_cert=None, cipher=None, x509_cert=None): + """ + This method creates the SSLContext object used by Paho to authenticate the connection. + """ + logger.debug("creating a SSL context") + ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) + + if server_verification_cert: + logger.debug("configuring SSL context with custom server verification cert") + ssl_context.load_verify_locations(cadata=server_verification_cert) + else: + logger.debug("configuring SSL context with default certs") + ssl_context.load_default_certs() + + if cipher: + try: + logger.debug("configuring SSL context with cipher suites") + ssl_context.set_ciphers(cipher) + except ssl.SSLError as e: + # TODO: custom error with more detail? 
+ raise e + + if x509_cert is not None: + logger.debug("configuring SSL context with client-side certificate and key") + ssl_context.load_cert_chain( + x509_cert.certificate_file, + x509_cert.key_file, + x509_cert.pass_phrase, + ) + + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.check_hostname = True + + return ssl_context + + +def _form_sas_uri(hostname, device_id, module_id=None): + if module_id: + return "{hostname}/devices/{device_id}/modules/{module_id}".format( + hostname=hostname, device_id=device_id, module_id=module_id + ) + else: + return "{hostname}/devices/{device_id}".format(hostname=hostname, device_id=device_id) diff --git a/dev_utils/dev_utils/random_content.py b/dev_utils/dev_utils/random_content.py index 60d2332af..d445c375a 100644 --- a/dev_utils/dev_utils/random_content.py +++ b/dev_utils/dev_utils/random_content.py @@ -5,7 +5,7 @@ import string import json import uuid -from azure.iot.device.iothub import Message +from azure.iot.device import Message JSON_CONTENT_TYPE = "application/json" JSON_CONTENT_ENCODING = "utf-8" diff --git a/dev_utils/dev_utils/service_helper.py b/dev_utils/dev_utils/service_helper.py index 066dfdf4c..5c8a881bc 100644 --- a/dev_utils/dev_utils/service_helper.py +++ b/dev_utils/dev_utils/service_helper.py @@ -15,12 +15,15 @@ def __init__( event_loop=None, executor=None, ): - self._event_loop = event_loop or asyncio.get_event_loop() + self._event_loop = event_loop or asyncio.get_running_loop() self._executor = executor or concurrent.futures.ThreadPoolExecutor() self._inner_object = ServiceHelperSync( iothub_connection_string, eventhub_connection_string, eventhub_consumer_group ) + def clear_incoming(self): + return self._inner_object.clear_incoming() + def set_identity(self, device_id, module_id): return self._inner_object.set_identity(device_id, module_id) @@ -47,6 +50,9 @@ async def invoke_method( response_timeout_in_seconds, ) + async def get_twin(self): + return await 
self._event_loop.run_in_executor(self._executor, self._inner_object.get_twin) + async def send_c2d( self, payload, diff --git a/dev_utils/dev_utils/service_helper_sync.py b/dev_utils/dev_utils/service_helper_sync.py index 6861b8dc7..f313048a3 100644 --- a/dev_utils/dev_utils/service_helper_sync.py +++ b/dev_utils/dev_utils/service_helper_sync.py @@ -99,6 +99,14 @@ def __init__( self.cv = threading.Condition() self.incoming_eventhub_events = {} + def clear_incoming(self): + """Flush the incoming queues""" + with self.incoming_patch_queue.mutex: + self.incoming_patch_queue.queue.clear() + + with self.cv: + self.incoming_eventhub_events.clear() + def set_identity(self, device_id, module_id): if device_id != self.device_id or module_id != self.module_id: self.device_id = device_id @@ -150,6 +158,9 @@ def invoke_method( return response + def get_twin(self): + return self._registry_manager.get_twin(self.device_id) + def send_c2d(self, payload, properties): if self.module_id: raise TypeError("sending C2D to modules is not supported") diff --git a/devbox_setup.md b/devbox_setup.md index 36d3c178a..efd574416 100644 --- a/devbox_setup.md +++ b/devbox_setup.md @@ -20,7 +20,7 @@ This will install not only relevant development and test dependencies, but also It is recommended to use [virtualenvwrapper](https://virtualenvwrapper.readthedocs.io/en/latest/install.html) for Unix-based platforms or [virtualenvwrapper-win](https://github.com/davidmarble/virtualenvwrapper-win) for Windows, in order to easily manage custom environments and switch Python versions, however this is optional. -## Environment Variables (Optional) +## Sample Environment Variables (Optional) If you wish to follow the samples exactly as written, you will need to set some environment variables on your system. These are not required however - if you wish to use different environment variables, or no environment variables at all, simply change the samples to retrieve these values from elsewhere. 
Additionally, different samples use different variables, so you would only need the ones relevant to samples you intend to use. @@ -32,4 +32,19 @@ If you wish to follow the samples exactly as written, you will need to set some * **X509_KEY_FILE**: The path to the X509 key * **X509_PASS_PHRASE**: The pass phrase for the X509 key (Only necessary if cert has a password) -**This is an incomplete list of environment variables** \ No newline at end of file +**This is an incomplete list of environment variables** + + +## E2E Testing Setup (Optional - SDK Developer) + +If you wish to run end to end tests locally, you'll need to configure some additional environment variables: + +* **IOTHUB_CONNECTION_STRING**: The connection string for your IoTHub (ideally iothubowner permissions) +* **EVENTHUB_CONNECTION_STRING**: The built-in Event Hub compatible endpoint of the above IoTHub + +**NOTE**: if you wish to use dedicated E2E resources, you may also prefix the above variables with `IOTHUB_E2E_` + +Additionally, you will need to add a messaging route with the following settings to the IoTHub in order for all tests to run correctly: +* Name: twin +* Endpoint: events +* Data Source: Device Twin Change Events diff --git a/migration_guide.md b/migration_guide.md deleted file mode 100644 index 5af5328e1..000000000 --- a/migration_guide.md +++ /dev/null @@ -1,144 +0,0 @@ -# IoTHub Python SDK Migration Guide - -This guide details the migration plan to move from the IoTHub Python v1 code base to the new and improved v2 -code base. Note that this guide assumes the use of asynchronous code. - -## Installing the IoTHub Python SDK - -- v1 - -```Shell -pip install azure-iothub-device-client - -``` - -- v2 - -```Shell -pip install azure-iot-device -``` - -## Creating a device client - -When creating a device client on the V1 client the protocol was specified on in the constructor. 
With the v2 SDK we are -currently only supporting the MQTT protocol so it only requires to supply the connection string when you create the client. - -### Symmetric Key authentication - -- v1 - -```Python - from iothub_client import IoTHubClient, IoTHubClientError, IoTHubTransportProvider, IoTHubClientResult - from iothub_client import IoTHubMessage, IoTHubMessageDispositionResult, IoTHubError, DeviceMethodReturnValue - - client = IoTHubClient(connection_string, IoTHubTransportProvider.MQTT) -``` - -- v2 - -```Python - from azure.iot.device.aio import IoTHubDeviceClient - from azure.iot.device import Message - - client = IoTHubDeviceClient.create_from_connection_string(connection_string) - await device_client.connect() -``` - -### x.509 authentication - -For x.509 device the v1 SDK required the user to supply the certificates in a call to set_options. Moving forward in the v2 -SDK, we only require for the user to call the create function with an x.509 object containing the path to the x.509 file and -key file with the optional pass phrase if necessary. 
- -- v1 - -```Python - from iothub_client import IoTHubClient, IoTHubClientError, IoTHubTransportProvider, IoTHubClientResult - from iothub_client import IoTHubMessage, IoTHubMessageDispositionResult, IoTHubError, DeviceMethodReturnValue - - client = IoTHubClient(connection_string, IoTHubTransportProvider.MQTT) - # Get the x.509 certificate information - client.set_option("x509certificate", X509_CERTIFICATE) - client.set_option("x509privatekey", X509_PRIVATEKEY) -``` - -- v2 - -```Python - from azure.iot.device.aio import IoTHubDeviceClient - from azure.iot.device import Message - - # Get the x.509 certificate path from the environment - x509 = X509( - cert_file=os.getenv("X509_CERT_FILE"), - key_file=os.getenv("X509_KEY_FILE"), - pass_phrase=os.getenv("X509_PASS_PHRASE") - ) - client = IoTHubDeviceClient.create_from_x509_certificate(hostname=hostname, device_id=device_id, x509=x509) - await device_client.connect() -``` - -## Sending Telemetry to IoTHub - -- v1 - -```Python - # create the device client - - message = IoTHubMessage("telemetry message") - message.message_id = "message id" - message.correlation_id = "correlation-id" - - prop_map = message.properties() - prop_map.add("property", "property_value") - client.send_event_async(message, send_confirmation_callback, user_ctx) -``` - -- v2 - -```Python - # create the device client - - message = Message("telemetry message") - message.message_id = "message id" - message.correlation_id = "correlation id" - - message.custom_properties["property"] = "property_value" - await client.send_message(message) -``` - -## Receiving a Message from IoTHub - -- v1 - -```Python - # create the device client - - def receive_message_callback(message, counter): - global RECEIVE_CALLBACKS - message = message.get_bytearray() - size = len(message_buffer) - print ( "the data in the message received was : <<<%s>>> & Size=%d" % (message_buffer[:size].decode('utf-8'), size) ) - map_properties = message.properties() - key_value_pair = 
map_properties.get_internals() - print ( "custom properties are: %s" % key_value_pair ) - return IoTHubMessageDispositionResult.ACCEPTED - - client.set_message_callback(message_listener_callback, RECEIVE_CONTEXT) -``` - -- v2 - -```Python - # create the device client - - # define behavior for receiving a message - def message_handler(message): - print("the data in the message received was ") - print(message.data) - print("custom properties are") - print(message.custom_properties) - - # set the message handler on the client - client.on_message_received = message_handler -``` diff --git a/migration_guide_iothub.md b/migration_guide_iothub.md new file mode 100644 index 000000000..c2b8c7b8e --- /dev/null +++ b/migration_guide_iothub.md @@ -0,0 +1,521 @@ +# Azure IoT Device SDK for Python Migration Guide - IoTHubDeviceClient/IoTHubModuleClient -> IoTHubSession + +This guide details how to update existing code for IoT Hub that uses an `azure-iot-device` V2 release to use a V3 release instead. + +**Note that currently V3 only presents an async set of APIs. This guide will be updated when that changes** + +For changes when using the Device Provisioning Service, please refer to `migration_guide_provisioning.md` in this same directory. + +The design goals for V3 were to make a more stripped back, simple API surface that allows for a greater flexibility for the end user, as well as improved reliability and clarity. We have attempted to remove as much implicit behavior as possible in order to give full control of functionality to the end user. Additionally, we have attempted to make the experience of using the API simpler to address common pitfalls, and make applications easier to write. + +## Connection Management +The most significant change in V3 is the removal of manual connection/disconnection. Connections are now managed automatically by a context manager. 
+ +#### V2 +```python +from azure.iot.device.aio import IoTHubDeviceClient + +async def main(): + client = IoTHubDeviceClient.create_from_connection_string("") + await client.connect() + # + await client.disconnect() +``` + +#### V3 +```python +from azure.iot.device import IoTHubSession + +async def main(): + async with IoTHubSession.from_connection_string("") as session: + # +``` + +When the context manager is entered, a connection will be established before running the block of code inside the context manager. After the block is done executing, a disconnection will occur upon context manager exit. You can consider that all code within the block is written with the expectation of a connection, as the context manager +represents a connection to the IoT Hub. + + +## Outgoing Operations +Initiating an operation works similarly to before, but now must be done within the block of the Session context manager. APIs will fail during invocation if they are not called from within the context manager. In the following example, we send a telemetry message, but the structure applies to any kind of operation + +#### V2 +```python +from azure.iot.device.aio import IoTHubDeviceClient + +async def main(): + client = IoTHubDeviceClient.create_from_connection_string("") + await client.connect() + + await client.send_message("hello world") + + await client.disconnect() +``` + +#### V3 +```python +from azure.iot.device import IoTHubSession + +async def main(): + async with IoTHubSession.from_connection_string("") as session: + await session.send_message("hello world") +``` + +Some of the APIs for operations have been changed. The following table illustrates how to use each operation from the V2 SDK in V3. 
+ +| Operation Type | IoTHubDeviceClient API (V2) | IoTHubModuleClient API (V2) | IoTHubSession API (V3) | +|-----------------------------|-------------------------------------|-------------------------------------|----------------------------------| +| Send Telemetry Message | `.send_message()` | `.send_message()` | `.send_message()` | +| Send Routed Message | **N/A** | `.send_message_to_output()` | **NOT YET AVAILABLE** | +| Send Direct Method Response | `.send_method_response()` | `.send_method_response()` | `.send_direct_method_response()` | +| Update Reported Properties | `.patch_twin_reported_properties()` | `.patch_twin_reported_properties()` | `.update_reported_properties()` | +| Get Twin | `.get_twin()` | `.get_twin()` | `.get_twin()` | +| Get Blob Storage Info | `.get_storage_info_for_blob()` | **N/A** | **NOT YET AVAILABLE** | +| Notify Blob Upload Status | `.notify_blob_upload_status()` | **N/A** | **NOT YET AVAILABLE** | +| Invoke Direct Method | **N/A** | `.invoke_method()` | **NOT YET AVAILABLE** | + + +## Incoming Data +Incoming data receives are now implemented with a context manager and asynchronous iterator rather than using callbacks. In the following example we use incoming IoT Hub messages, but the syntax applies to any kind of received data. 
+#### V2
+```python
+from azure.iot.device.aio import IoTHubDeviceClient
+
+async def main():
+    client = IoTHubDeviceClient.create_from_connection_string("")
+
+    # define behavior for receiving a message
+    def message_handler(message):
+        print("the data in the message received was ")
+        print(message.data)
+        print("custom properties are")
+        print(message.custom_properties)
+
+    # set the message handler on the client
+    client.on_message_received = message_handler
+
+    await client.connect()
+
+    # Loop until program is terminated
+    while True:
+        await asyncio.sleep(1)
+```
+
+#### V3
+```python
+from azure.iot.device import IoTHubSession
+
+async def main():
+    async with IoTHubSession.from_connection_string("") as session:
+        async with session.messages() as messages:
+            async for message in messages:
+                print("the data in the message received was ")
+                print(message.payload)
+                print("custom properties are")
+                print(message.custom_properties)
+```
+
+Similar to managing the connection with a context manager, in this example the `session.messages()` context manager will enable receiving IoT Hub messages upon entry, and disable receiving them upon exit, ensuring you are only receiving data when you wish to. If the outer `IoTHubSession` context manager represents the duration of a connection to an IoT Hub, then this `session.messages()` context manager represents the duration of receiving a specific data type from an IoT Hub.
+
+The context manager returns `messages` in this example, an asynchronous iterator, which can be iterated over as messages are received, asynchronously suspending iteration until one arrives. You can think of the code inside this loop as being the same as the code that would have previously been put inside a callback in V2.
+
+The following table indicates how to use the various data receives from the V2 SDK in V3, as not only is the programming model different, but some of the names have changed.
+ +| Incoming Data Type | IoTHubDeviceClient Callback (V2) | IoTHubModuleClient Callback (V2) | IoTHubSession Context Manager (V3) | +|--------------------------|----------------------------------------------|----------------------------------------------|------------------------------------| +| C2D Messages | `.on_message_received` | **N/A** | `.messages()` | +| Input Messages | **N/A** | `.on_message_received` | **NOT YET AVAILABLE** | +| Direct Method Requests | `.on_method_request_received` | `.on_method_request_received` | `.direct_method_requests()` | +| Desired Property Updates | `.on_twin_desired_properties_patch_received` | `.on_twin_desired_properties_patch_received` | `.desired_property_updates()` | + +Note that some of the data objects themselves have also been slightly changed for V3. Refer to the sections on Message Objects and Direct Method Objects for more information. + +## Responding to Network Failure + +In the V2 IoTHubDeviceClient and IoTHubModuleClient, the default behavior was to try and re-establish connections that failed. In the V3 `IoTHubSession`, not only is this not the default behavior, but this behavior is not supported at all. In order to provide flexibility surrounding reconnect scenarios, we have changed the design to put control in the hands of the end user. No longer will there be any confusion as to the connection state - it will be directly and clearly reported. No longer are there implicit reconnect attempts that happen without user knowledge. + +To reconnect after a connection drop, simply wrap your `IoTHubSession` usage in a try/except block. All outgoing operation APIs, as well incoming data generators will raise `MQTTError` upon a lost connection, so you can catch that, and respond with a reconnect attempt + +Additionally, in the case where you cannot connect, the `IoTHubSession` context manager itself will raise `MQTTConnectionFailedError`, which you can catch and respond to with a reconnect attempt. 
+ +In the following example, we attempt to connect and wait for incoming C2D messages. If we are connected and the connection is dropped, a reconnect attempt will be made after 5 seconds. If we attempt to reconnect and fail doing so, we will try again after 10 seconds. + +#### V3 +```python +from azure.iot.device import IoTHubSession, MQTTError, MQTTConnectionFailedError + +async def main(): + while True: + try: + async with IoTHubSession.from_connection_string("") as session: + async with session.messages() as messages: + async for message in messages: + print(message.payload) + except MQTTError: + print("Connection was lost. Trying again in 5 seconds") + await asyncio.sleep(5) + except MQTTConnectionFailedError: + print("Could not connect. Trying again in 10 seconds") + await asyncio.sleep(10) +``` + +This does result in a slight increase in complexity over V2 where all of this logic was hidden internally in the clients, but we feel as though this will end up being simpler to use, as connection loss will always result in a thrown exception, immediately identifying the problem. Furthermore, this will eliminate any clashes between user-initiated connects, and implicit reconnection attempts by making the end user the authority on controlling the connection in all respects. + +## SAS (Shared Access Signature) Authentication + +Several types of IoT Hub authentication use SAS (Shared Access Signature) tokens. In V2 these all had their own factory methods. In V3 this has been changed. Additionally, significant changes have been made regarding credential expiration. + +### Connection String +Connection string based SAS authentication functions the same as it did before, although the factory method has been renamed. Use `.from_connection_string()` with V3 instead of the old `.create_from_connection_string()` method. 
+ +### Shared Access Key / Symmetric Key +Creating a client that uses a shared access key to authenticate is now done via the `IoTHubSession` constructor directly, instead of the old `.create_from_symmetric_key()` method. + +#### V2 +```python +from azure.iot.device.aio import IoTHubDeviceClient + +client = IoTHubDeviceClient.create_from_symmetric_key( + symmetric_key="", + hostname="", + device_id="" +) +``` + +#### V3 +```python +from azure.iot.device import IoTHubSession + +session = IoTHubSession( + hostname="", + device_id="", + shared_access_key="", +) +``` + +### Custom SAS Token +This feature is currently not yet fully supported on V3. It will be soon. + +### SAS Token expiration + +In the past, the V2 SDKs engaged in "SAS Token renewal", when using SAS authentication types. What this meant, was that when a SAS token being used as a credential was about to expire, the client would either generate a new one, or ask for a new one, depending on the specific type of SAS auth. Then, the client would implicitly disconnect from the IoTHub and then connect once more with the new credential. + +This was somewhat problematic, as it once again introduced implicit behavior that affected the connection, and in the case where something went wrong, there was no easy way to report it. Furthermore, even if the user had specified they wanted to manage the connection themselves, turning off both auto-connect, and auto-reconnect behaviors, these implicit reconnects still would need to occur. + +For V3 we have decided to simply not do this. When a SAS token expires, the connection will be dropped. This brings the behavior in line with how X509 certificates behave - when they expire, the connection will be dropped (raising `MQTTError`, same as any other connection loss), and you should re-instantiate the `IoTHubSession` object, just the same as you would respond to any other connection loss. 
+ +As a result, the example from the "Network Failure" section above also handles SAS expiration. + +You can still customize the lifespan of generated SAS tokens by providing the optional `sastoken_ttl` keyword argument when instantiating an `IoTHubSession` object, either directly with the constructor, or with a factory method. + + +## X509 Certificate Authentication +X509 authentication is now provided via the new `ssl_context` keyword for the `IoTHubSession` constructor, rather than having it's own `.create_from_x509_certificate()` method. This is to allow additional flexibility for customers who wish for more control over their TLS/SSL authentication. See "TLS/SSL customization" below for more information. + +#### V2 +```python +from azure.iot.device.aio import IoTHubDeviceClient +from azure.iot.device import X509 + +x509 = X509( + cert_file="", + key_file="", + pass_phrase="", +) + +client = IoTHubDeviceClient.create_from_x509_certificate( + hostname="", + device_id="", + x509=x509, +) +``` + +#### V3 +```python +from azure.iot.device import IoTHubSession +import ssl + +ssl_context = ssl.create_default_context() +ssl_context.load_cert_chain( + certfile="", + keyfile="", + password="", +) + +client = IoTHubSession( + hostname="", + device_id="", + ssl_context=ssl_context, +) +``` + +Note that SSLContexts can be used with the `.from_connection_string()` factory method as well, so V3 now fully supports X509 connection strings. 
+ +#### V3 +```python +from azure.iot.device import IoTHubSession +import ssl + +ssl_context = ssl.create_default_context() +ssl_context.load_cert_chain( + certfile="", + keyfile="", + password="", +) + +client = IoTHubSession.from_connection_string( + "", + ssl_context=ssl_context, +) +``` + +## TLS/SSL Customization +To allow users more flexibility with TLS/SSL authentication, we have added the ability to inject an `SSLContext` object into the `IoTHubSession` via the optional `ssl_context` keyword argument that is present on the constructor and factory methods. As a result, some features previously handled via client APIs are now expected to have been directly set on the injected `SSLContext`. + +By moving to a model that allows `SSLContext` injection we can allow for users to modify any aspect of their `SSLContext`, not just the ones we previously supported via API. + +### Server Verification Certificates (CA certs) +#### V2 +```python +from azure.iot.device.aio import IoTHubDeviceClient + +certfile = open("") +root_ca_cert = certfile.read() + +client = IoTHubDeviceClient.create_from_connection_string( + "", + server_verification_cert=root_ca_cert +) +``` + +#### V3 +```python +from azure.iot.device import IoTHubSession +import ssl + +ssl_context = ssl.create_default_context( + cafile="", +) + +client = IoTHubSession.from_connection_string( + "", + ssl_context=ssl_context, +) +``` + +### Cipher Suites +#### V2 +```python +from azure.iot.device.aio import IoTHubDeviceClient + +client = IoTHubDeviceClient.create_from_connection_string( + "", + cipher="" +) +``` + +#### V3 +```python +from azure.iot.device import IoTHubSession +import ssl + +ssl_context = ssl.create_default_context() +ssl_context.set_ciphers("") + +client = IoTHubSession.from_connection_string( + "", + ssl_context=ssl_context, +) +``` + +## Data object changes + +### Message Objects +Some changes have been made to the `Message` object used for sending and receiving data. 
+* The `.data` attribute is now called `.payload` for consistency with other objects in the API +* The `message_id` parameter is no longer part of the constructor arguments. It should be manually added as an attribute, just like all other attributes +* The payload of a received Message is now a unicode string value instead of a bytestring value. +It will be decoded according to the content encoding property sent along with the message. + +#### V2 +```python +from azure.iot.device import Message + +payload = "this is a payload" +message_id = "1234" +m = Message(data=payload, message_id=message_id) + +assert m.data == payload +assert m.message_id == message_id +``` + +#### V3 +```python +from azure.iot.device import Message + +payload = "this is a payload" +message_id = "1234" +m = Message(payload=payload) +m.message_id = message_id + +assert m.payload == payload +``` + +### Direct Method Objects + +`MethodRequest` and `MethodResponse` objects from V2 have been renamed to `DirectMethodRequest` and `DirectMethodResponse` respectively. They are otherwise identical. + +## Removed Keyword Arguments + +Some keyword arguments provided at client creation in V2 have been removed in V3 as they are no longer necessary.
+ +| V2 | V3 | Explanation | +|-----------------------------|------------------|----------------------------------------------------------| +| `connection_retry` | **REMOVED** | No automatic reconnect | +| `connection_retry_interval` | **REMOVED** | No automatic reconnect | +| `auto_connect` | **REMOVED** | Connection managed by `IoTHubSession` context manager | +| `ensure_desired_properties` | **REMOVED** | No more implicit twin updates | +| `gateway_hostname` | **REMOVED** | Supported via `hostname` parameter | +| `server_verification_cert` | **REMOVED** | Supported via SSL injection | +| `cipher` | **REMOVED** | Supported via SSL injection | + + + +## Managing Lifecycle of `IoTHubSession` +The above examples are fairly simple, but what about applications that do multiple things? And how can we handle a graceful exit? + +The following example from V2 demonstrates an application that receives both C2D messages and Direct Method Requests, while also sending telemetry every 5 seconds, until a `KeyboardInterrupt` +is issued. 
+#### V2 +```python +from azure.iot.device.aio import IoTHubDeviceClient +from azure.iot.device import MethodResponse +import asyncio +import time + +def create_client(): + client = IoTHubDeviceClient.create_from_connection_string("") + + # define behavior for receiving a message + def message_handler(message): + print("the data in the message received was ") + print(message.data) + + # define behavior for receiving direct methods + async def method_handler(method_request): + if method_request.name == "foo": + result = do_foo() + payload = {"result": result} + status = 200 + print("Completed foo") + else: + payload = {} + status = 400 + print("Unknown direct method request") + method_response = MethodResponse.create_from_method_request(method_request, status, payload) + await client.send_method_response(method_response) + + # set the incoming data handlers on the client + client.on_message_received = message_handler + client.on_method_request_received = method_handler + + return client + +async def send_telemetry(client): + while True: + # Send the current time every 5 seconds + curr_time = time.time() + print("Sending Telemetry...") + try: + await client.send_message(str(curr_time)) + except Exception: + print("Sending telemetry failed") + await asyncio.sleep(5) + +async def main(): + client = create_client() + + await client.connect() + + try: + await send_telemetry(client) + except KeyboardInterrupt: + print("User exit!") + except Exception: + print("Unexpected error") + finally: + # Shut down for graceful exit + await client.shutdown() + +if __name__ == "__main__": + asyncio.run(main()) +``` + +Here is the same application, this time written for the V3 `IoTHubSession`. 
+ +#### V3 +```python +from azure.iot.device import IoTHubSession, DirectMethodResponse, MQTTError, MQTTConnectionFailedError +import asyncio +import time + +async def recurring_telemetry(session): + while True: + # Send the current time every 5 seconds + curr_time = time.time() + print("Sending Telemetry...") + await session.send_message(str(curr_time)) + await asyncio.sleep(5) + +async def receive_c2d_messages(session): + async with session.messages() as messages: + async for message in messages: + print("the data in the message received was ") + print(message.payload) + +async def receive_direct_method_requests(session): + async with session.direct_method_requests() as method_requests: + async for method_request in method_requests: + if method_request.name == "foo": + result = do_foo() + payload = {"result": result} + status = 200 + print("Completed foo") + else: + payload = {} + status = 400 + print("Unknown direct method request") + method_response = DirectMethodResponse.create_from_method_request(method_request, status, payload) + await session.send_direct_method_response(method_response) + +async def main(): + while True: + try: + async with IoTHubSession.from_connection_string("") as session: + await asyncio.gather( + recurring_telemetry(session), + receive_c2d_messages(session), + receive_direct_method_requests(session), + ) + except KeyboardInterrupt: + print("User exit!") + raise + except MQTTError: + print("Connection was lost. Trying again in 5 seconds") + await asyncio.sleep(5) + except MQTTConnectionFailedError: + print("Could not connect. Trying again in 10 seconds") + await asyncio.sleep(10) + except Exception: + print("Unexpected error") + + +if __name__ == "__main__": + asyncio.run(main()) +``` +Some implementation notes on this V3 sample: + +* Unlike in the V2 sample, there is no need for a `.shutdown()` method as all cleanup is handled by the `IoTHubSession` context manager. 
+* Reconnection logic must be directly implemented in the V3 sample, in contrast to it being automatically done in the background in V2 (as explained in the "Responding to Network Failure" section above). This will also allow for handling SAS token expiration, which was also done implicitly in V2 (as explained in the "SAS Token Expiration" section) +* Within the `IoTHubSession` context manager, the session object is passed around to various coroutines that can be run together with `asyncio.gather`. There are other ways this could be implemented as well, the point here is to run all your logic from the block of code inside the context manager. +* When a `KeyboardInterrupt` is issued by the user, the application breaks out of the context manager, triggering cleanup. \ No newline at end of file diff --git a/migration_guide_provisioning.md b/migration_guide_provisioning.md new file mode 100644 index 000000000..979c9bc29 --- /dev/null +++ b/migration_guide_provisioning.md @@ -0,0 +1,174 @@ +# Azure IoT Device SDK for Python Migration Guide - ProvisioningDeviceClient -> ProvisioningSession + +This guide details how to update existing code for IoT Hub provisioning that uses an `azure-iot-device` V2 release to use a V3 release instead. + +**Note that currently V3 only presents an async set of APIs. This guide will be updated when that changes** + +For changes when communicating between a device and IoT Hub, please refer to `migration_guide_iothub.md` in this same directory. + +## Default usage of the Global Provisioning Endpoint +The Global Provisioning Endpoint - `global.azure-devices-provisioning.net` previously had to be manually provided via the `provisioning_host` argument to any factory method. For V3, the argument has been renamed to `provisioning_endpoint` and is now provided directly to the `ProvisioningSession` constructor. It defaults to the Global Provisioning Endpoint if not provided, so the vast majority of users can simply not provide anything.
The only time the `provisioning_endpoint` argument is necessary in V3 is if your solution involves using a private endpoint. + + +## Provisioning using Shared Access Key (Symmetric Key) +Shared access key authentication (formerly called 'symmetric key authentication') is now provided via the `shared_access_key` argument instead of the `symmetric_key` parameter. This parameter is now provided directly to the `ProvisioningSession` constructor, rather than using the `create_from_symmetric_key()` factory method. + +#### V2 +```python +from azure.iot.device.aio import ProvisioningDeviceClient + +async def main(): + provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host="global.azure-devices-provisioning.net", + registration_id="", + id_scope="", + symmetric_key="", + ) + + provisioning_device_client.provisioning_payload = "" + + result = await provisioning_device_client.register() +``` + +#### V3 +```python +from azure.iot.device import ProvisioningSession + +async def main(): + async with ProvisioningSession( + id_scope="", + registration_id="", + shared_access_key="" + ) as session: + result = await session.register(payload="") +``` + + +## Provisioning using X509 Certificates +X509 authentication is now provided via the new `ssl_context` keyword argument for the `ProvisioningSession` constructor, rather than using the `.create_from_x509_certificate()` factory method. This is to allow additional flexibility for +customers who wish to have more control over their TLS/SSL authentication. See "TLS/SSL customization" below for more information.
+ +#### V2 +```python +from azure.iot.device.aio import ProvisioningDeviceClient +from azure.iot.device import X509 + +async def main(): + x509 = X509( + cert_file="", + key_file="", + pass_phrase="", + ) + + provisioning_device_client = ProvisioningDeviceClient.create_from_x509_certificate( + provisioning_host="global.azure-devices-provisioning.net", + registration_id="", + id_scope="", + x509=x509, + ) + + provisioning_device_client.provisioning_payload = "" + + result = await provisioning_device_client.register() +``` + +#### V3 +```python +from azure.iot.device import ProvisioningSession +import ssl + +async def main(): + ssl_context = ssl.create_default_context() + ssl_context.load_cert_chain( + certfile="", + keyfile="", + password="", + ) + + async with ProvisioningSession( + id_scope="", + registration_id="", + ssl_context=ssl_context + ) as session: + result = await session.register(payload="") +``` + +## TLS/SSL Customization +To allow users more flexibility with TLS/SSL authentication, we have added the ability to inject an `SSLContext` object into the `ProvisioningSession` via the optional `ssl_context` keyword argument that is present on the constructor and factory methods. As a result, some features previously handled via client APIs are now expected to have been directly set on the injected `SSLContext`. + +By moving to a model that allows `SSLContext` injection we can allow for users to modify any aspect of their `SSLContext`, not just the ones we previously supported via API. 
+ + +### Server Verification Certificates (CA certs) +#### V2 +```python +from azure.iot.device.aio import ProvisioningDeviceClient + +certfile = open("") +root_ca_cert = certfile.read() + +provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host="global.azure-devices-provisioning.net", + registration_id="", + id_scope="", + symmetric_key="", + server_verification_cert=root_ca_cert, +) +``` + +#### V3 +```python +from azure.iot.device import ProvisioningSession +import ssl + +ssl_context = ssl.create_default_context( + cafile="", +) + +session = ProvisioningSession( + registration_id="", + id_scope="", + shared_access_key="", + ssl_context=ssl_context, +) +``` + +### Cipher Suites +#### V2 +```python +from azure.iot.device.aio import ProvisioningDeviceClient + +provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host="global.azure-devices-provisioning.net", + registration_id="", + id_scope="", + symmetric_key="", + cipher="", +) +``` + +#### V3 +```python +from azure.iot.device import ProvisioningSession +import ssl + +ssl_context = ssl.create_default_context() +ssl_context.set_ciphers("") + +session = ProvisioningSession( + registration_id="", + id_scope="", + shared_access_key="", + ssl_context=ssl_context, +) +``` + +## Removed Keyword Arguments + +Some keyword arguments provided at client creation in V2 have been removed in V3 as they are no longer necessary.
+ +| V2 | V3 | Explanation | +|-----------------------------|------------------|----------------------------------------------------------| +| `gateway_hostname` | **REMOVED** | Unsupported scenario (was unnecessary in V2) | +| `server_verification_cert` | **REMOVED** | Supported via SSL injection | +| `cipher` | **REMOVED** | Supported via SSL injection | diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 000000000..7024cfdec --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +show_error_codes = True diff --git a/pytest.ini b/pytest.ini index 07ac113a3..ee54494b7 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,7 @@ [pytest] +mock_use_standalone_module = true +asyncio_mode = auto testdox_format = plaintext addopts = --testdox --timeout 20 --ignore e2e --ignore tests/e2e norecursedirs=__pycache__, *.egg-info -filterwarnings = ignore::DeprecationWarning \ No newline at end of file +filterwarnings = ignore::DeprecationWarning diff --git a/requirements_dev.txt b/requirements_dev.txt index bc392a27b..00e79e83f 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -6,3 +6,4 @@ pre-commit twine pylint rope +mypy diff --git a/requirements_test.txt b/requirements_test.txt index d7bfcfe21..18b4e3c96 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -1,10 +1,13 @@ pytest -pytest-mock -pytest-asyncio <= 0.16 # Can remove this once Python 3.6 support is dropped +pytest-mock>=3.10.0 +pytest-asyncio>=0.20.3 pytest-testdox>=1.1.1 pytest-cov pytest-timeout +pytest-lazy-fixture +mock # Need to use instead of builtin for backports to 3.7 flake8 +cryptography # Needed for cert generation and e2e azure-iot-hub # Only needed for iothub e2e azure-iothub-provisioningserviceclient >= 1.2.0 # Only needed for provisioning e2e azure-eventhub # Only needed for iothub e2e diff --git a/samples/README.md b/samples/README.md deleted file mode 100644 index 94ca0b56c..000000000 --- a/samples/README.md +++ /dev/null @@ -1,158 +0,0 @@ -# Samples for the Azure 
IoT Hub Device SDK - -This directory contains samples showing how to use the various features of the Microsoft Azure IoT Hub service from a device running the Azure IoT Hub Device SDK. - -## Quick Start - Simple Telemetry Sample (send message) - -**Note that this sample is configured for Python 3.7+.** To ensure that your Python version is up to date, run `python --version`. If you have both Python 2 and Python 3 installed (and are using a Python 3 environment for this SDK), then install all libraries using `pip3` as opposed to `pip`. This ensures that the libraries are installed to your Python 3 runtime. - -1. Install the [Azure CLI](https://docs.microsoft.com/cli/azure/install-azure-cli?view=azure-cli-latest) (or use the [Azure Cloud Shell](https://shell.azure.com/)) and use it to [create an Azure IoT Hub](https://docs.microsoft.com/cli/azure/iot/hub?view=azure-cli-latest#az_iot_hub_create). - - ```bash - az iot hub create --resource-group --name - ``` - - * Note that this operation may take a few minutes. - -2. Add the IoT Extension to the Azure CLI, and then [register a device identity](https://docs.microsoft.com/cli/azure/iot/hub/device-identity?view=azure-cli-latest#az_iot_hub_device_identity_create) - - ```bash - az extension add --name azure-iot - az iot hub device-identity create --hub-name --device-id - ``` - -3. [Retrieve your Device Connection String](https://docs.microsoft.com/cli/azure/iot/hub/device-identity/connection-string?view=azure-cli-latest#az_iot_hub_device_identity_connection_string_show) using the Azure CLI - - ```bash - az iot hub device-identity connection-string show --device-id --hub-name - ``` - - It should be in the format: - - ```Text - HostName=.azure-devices.net;DeviceId=;SharedAccessKey= - ``` - -4. 
[Begin monitoring for telemetry](https://docs.microsoft.com/cli/azure/iot/hub?view=azure-cli-latest#az_iot_hub_monitor_events) on your IoT Hub using the Azure CLI - - ```bash - az iot hub monitor-events --hub-name --output json - ``` - -5. On your device, set the Device Connection String as an environment variable called `IOTHUB_DEVICE_CONNECTION_STRING`. - - **Windows (cmd)** - - ```cmd - set IOTHUB_DEVICE_CONNECTION_STRING= - ``` - - * Note that there are **NO** quotation marks around the connection string. - - **Linux (bash)** - - ```bash - export IOTHUB_DEVICE_CONNECTION_STRING="" - ``` - -6. Once the Device Connection String is set, run the following code from [simple_send_message.py](simple_send_message.py) on your device from the terminal or your IDE: - - ```python - import os - import asyncio - from azure.iot.device.aio import IoTHubDeviceClient - - - async def main(): - # Fetch the connection string from an environment variable - conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - - # Create instance of the device client using the authentication provider - device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - - # Connect the device client. - await device_client.connect() - - # Send a single message - print("Sending message...") - await device_client.send_message("This is a message that is being sent") - print("Message successfully sent!") - - # finally, shut down the client - await device_client.shutdown() - - - if __name__ == "__main__": - asyncio.run(main()) - ``` - -7. Check the Azure CLI output to verify that the message was received by the IoT Hub. You should see the following output: - - ```bash - Starting event monitor, use ctrl-c to stop... - event: - origin: - payload: This is a message that is being sent - ``` - -8. Your device is now able to connect to Azure IoT Hub! - - - - -## Read this if you want to run the sample using GitHub Codespaces - -You can use Github Codespaces to be up and running quickly! 
Here are the steps to follow. - -**1) Make sure you have the prerequisites** - -In order to run the device samples you will first need the following prerequisites: - -* An Azure IoT Hub instance. [Link if you don't.](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-create-through-portal) -* A device identity for your device. [Link if you don't.](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-create-through-portal#register-a-new-device-in-the-iot-hub) - -**2) Create and open Codespace** - -* Select the Codespaces tab and the "New codespace" button - - ![screen shot of create codespace](./media/codespace.png) - -* Once the Codespace is open, all required packages to run the samples will be setup for you - -**3) Set the DEVICE_CONNECTION_STRING environment variable** - -Set the Device Connection String as an environment variable called `IOTHUB_DEVICE_CONNECTION_STRING`. - -```bash -export IOTHUB_DEVICE_CONNECTION_STRING="" -``` - -**4) Run it** - -Run the sample using the following commands: - -```bash -cd azure-iot-device/samples -python3 simple_send_message.py -``` - -## Additional Samples - -Further samples with more complex IoT Hub scenarios are contained in the [async-hub-scenarios](async-hub-scenarios) directory, including: - -* Send multiple telemetry messages from a Device -* Receive Cloud-to-Device (C2D) messages on a Device -* Send and receive updates to device twin -* Receive direct method invocations -* Upload file into an associated Azure storage account - -Further samples with more complex IoT Edge scenarios involving IoT Edge modules and downstream devices are contained in the [async-edge-scenarios](async-edge-scenarios) directory, including: - -* Send multiple telemetry messages from a Module -* Receive input messages on a Module -* Send messages to a Module Output -* Send messages to IoT Edge from a downstream or 'leaf' device - -Samples for the synchronous clients are contained in the [sync-samples](sync-samples) directory. 
- -Samples for use of Azure IoT Plug and Play are contained in the [pnp](pnp) directory. diff --git a/samples/async-edge-scenarios/README.md b/samples/async-edge-scenarios/README.md deleted file mode 100644 index 59f6644ba..000000000 --- a/samples/async-edge-scenarios/README.md +++ /dev/null @@ -1,25 +0,0 @@ -# Advanced IoT Edge Scenario Samples for the Azure IoT Hub Device SDK - -This directory contains samples showing how to use the various features of Azure IoT Hub Device SDK with Azure IoT Edge. - -**Please note** that IoT Edge solutions are scoped to Linux containers and devices, documented [here](https://docs.microsoft.com/en-us/azure/iot-edge/tutorial-python-module#solution-scope). Please see [this blog post](https://techcommunity.microsoft.com/t5/internet-of-things/linux-modules-with-azure-iot-edge-on-windows-10-iot-enterprise/ba-p/1407066) to learn more about using Linux containers for IoT Edge on Windows devices. - -**These samples are written to run in Python 3.7+**, but can be made to work with Python 3.6 with a slight modification as noted in each sample: - -```python -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() -``` - -In order to use these samples, they **must** be run from inside an Edge container. - -## Included Samples -* [receive_data.py](receive_data.py) - Receive messages, twin patches, and method requests sent to an Edge module. -* [send_message.py](send_message.py) - Send multiple telmetry messages in parallel from an Edge module to the Azure IoT Hub or Azure IoT Edge. 
-* [send_message_to_output.py](send_message_to_output.py) - Send multiple messages in parallel from an Edge module to a specific output -* [send_message_downstream.py](send_message_downstream.py) - Send messages from a downstream or 'leaf' device to IoT Edge diff --git a/samples/async-edge-scenarios/invoke_method_on_module.py b/samples/async-edge-scenarios/invoke_method_on_module.py deleted file mode 100644 index 15c7b5446..000000000 --- a/samples/async-edge-scenarios/invoke_method_on_module.py +++ /dev/null @@ -1,41 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# ------------------------------------------------------------------------- - -import asyncio -from azure.iot.device.aio import IoTHubModuleClient - -messages_to_send = 10 - - -async def main(): - # Inputs/Outputs are only supported in the context of Azure IoT Edge and module client - # The module client object acts as an Azure IoT Edge module and interacts with an Azure IoT Edge hub - module_client = IoTHubModuleClient.create_from_edge_environment() - - # Connect the client. 
- await module_client.connect() - fake_method_params = { - "methodName": "doSomethingInteresting", - "payload": "foo", - "responseTimeoutInSeconds": 5, - "connectTimeoutInSeconds": 2, - } - response = await module_client.invoke_method( - device_id="fakeDeviceId", module_id="fakeModuleId", method_params=fake_method_params - ) - print("Method Response: {}".format(response)) - - # Finally, shut down the client - await module_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-edge-scenarios/receive_data.py b/samples/async-edge-scenarios/receive_data.py deleted file mode 100644 index f084af933..000000000 --- a/samples/async-edge-scenarios/receive_data.py +++ /dev/null @@ -1,99 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import asyncio -import signal -import threading -from azure.iot.device.aio import IoTHubModuleClient -from azure.iot.device import MethodResponse - - -# Event indicating client stop -stop_event = threading.Event() - - -def create_client(): - # The client object is used to interact with your Azure IoT hub. 
- client = IoTHubModuleClient.create_from_edge_environment() - - # Define behavior for receiving an input message on input1 and input2 - # NOTE: this could be a coroutine or a function - def message_handler(message): - if message.input_name == "input1": - print("Message received on INPUT 1") - print("the data in the message received was ") - print(message.data) - print("custom properties are") - print(message.custom_properties) - elif message.input_name == "input2": - print("Message received on INPUT 2") - print("the data in the message received was ") - print(message.data) - print("custom properties are") - print(message.custom_properties) - else: - print("message received on unknown input") - - # Define behavior for receiving a twin desired properties patch - # NOTE: this could be a coroutine or function - def twin_patch_handler(patch): - print("the data in the desired properties patch was: {}".format(patch)) - - # Define behavior for receiving methods - async def method_handler(method_request): - if method_request.name == "get_data": - print("Received request for data") - method_response = MethodResponse.create_from_method_request( - method_request, 200, "some data" - ) - await client.send_method_response(method_response) - else: - print("Unknown method request received: {}".format(method_request.name)) - method_response = MethodResponse.create_from_method_request(method_request, 400, None) - await client.send_method_response(method_response) - - # set the received data handlers on the client - client.on_message_received = message_handler - client.on_twin_desired_properties_patch_received = twin_patch_handler - client.on_method_request_received = method_handler - - return client - - -async def run_sample(client): - # Customize this coroutine to do whatever tasks the module initiates - # e.g. 
sending messages - await client.connect() - while not stop_event.is_set(): - await asyncio.sleep(1000) - - -def main(): - # NOTE: Client is implicitly connected due to the handler being set on it - client = create_client() - - # Define a handler to cleanup when module is is terminated by Edge - def module_termination_handler(signal, frame): - print("IoTHubClient sample stopped by Edge") - stop_event.set() - - # Set the Edge termination handler - signal.signal(signal.SIGTERM, module_termination_handler) - - # Run the sample - loop = asyncio.get_event_loop() - try: - loop.run_until_complete(run_sample(client)) - except Exception as e: - print("Unexpected error %s " % e) - raise - finally: - print("Shutting down IoT Hub Client...") - loop.run_until_complete(client.shutdown()) - loop.close() - - -if __name__ == "__main__": - main() diff --git a/samples/async-edge-scenarios/send_message.py b/samples/async-edge-scenarios/send_message.py deleted file mode 100644 index b3e36a596..000000000 --- a/samples/async-edge-scenarios/send_message.py +++ /dev/null @@ -1,44 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import asyncio -import uuid -from azure.iot.device.aio import IoTHubModuleClient -from azure.iot.device import Message - -messages_to_send = 10 - - -async def main(): - # The client object is used to interact with your Azure IoT hub. - module_client = IoTHubModuleClient.create_from_edge_environment() - - # Connect the client. 
- await module_client.connect() - - async def send_test_message(i): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.correlation_id = "correlation-1234" - msg.custom_properties["tornado-warning"] = "yes" - await module_client.send_message(msg) - print("done sending message #" + str(i)) - - # send `messages_to_send` messages in parallel - await asyncio.gather(*[send_test_message(i) for i in range(1, messages_to_send + 1)]) - - # Finally, shut down the client - await module_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-edge-scenarios/send_message_downstream.py b/samples/async-edge-scenarios/send_message_downstream.py deleted file mode 100644 index 66a0c26dc..000000000 --- a/samples/async-edge-scenarios/send_message_downstream.py +++ /dev/null @@ -1,58 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -import asyncio -import uuid -from azure.iot.device.aio import IoTHubDeviceClient -from azure.iot.device import Message - -messages_to_send = 10 - - -async def main(): - # The connection string for a device should never be stored in code. For the sake of simplicity we're using an environment variable here. 
- # NOTE: connection string must contain ;GatewayHostName= - # make sure your IoT Edge box is setup as a 'transparent gateway' per the IOT Edge documentation - conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - # path to the root ca cert used on your iot edge device (must copy the pem file to this downstream device) - # example: /home/azureuser/edge_certs/azure-iot-test-only.root.ca.cert.pem - ca_cert = os.getenv("IOTEDGE_ROOT_CA_CERT_PATH") - - certfile = open(ca_cert) - root_ca_cert = certfile.read() - - # The client object is used to interact with your Azure IoT Edge device. - device_client = IoTHubDeviceClient.create_from_connection_string( - connection_string=conn_str, server_verification_cert=root_ca_cert - ) - - # Connect the client. - await device_client.connect() - - async def send_test_message(i): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.correlation_id = "correlation-1234" - msg.custom_properties["tornado-warning"] = "yes" - await device_client.send_message(msg) - print("done sending message #" + str(i)) - - # send `messages_to_send` messages in parallel - await asyncio.gather(*[send_test_message(i) for i in range(1, messages_to_send + 1)]) - - # Finally, shut down the client - await device_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-edge-scenarios/send_message_to_output.py b/samples/async-edge-scenarios/send_message_to_output.py deleted file mode 100644 index 45bd4c223..000000000 --- a/samples/async-edge-scenarios/send_message_to_output.py +++ /dev/null @@ -1,45 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import asyncio -import uuid -from azure.iot.device.aio import IoTHubModuleClient -from azure.iot.device import Message - -messages_to_send = 10 - - -async def main(): - # Inputs/Outputs are only supported in the context of Azure IoT Edge and module client - # The module client object acts as an Azure IoT Edge module and interacts with an Azure IoT Edge hub - module_client = IoTHubModuleClient.create_from_edge_environment() - - # Connect the client. - await module_client.connect() - - # Send a filled out Message object - async def send_test_message(i): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.correlation_id = "correlation-1234" - msg.custom_properties["tornado-warning"] = "yes" - await module_client.send_message_to_output(msg, "twister") - print("done sending message #" + str(i)) - - await asyncio.gather(*[send_test_message(i) for i in range(1, messages_to_send)]) - - # Finally, shut down the client - await module_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-edge-scenarios/update_twin_reported_properties.py b/samples/async-edge-scenarios/update_twin_reported_properties.py deleted file mode 100644 index 72b66c1d1..000000000 --- a/samples/async-edge-scenarios/update_twin_reported_properties.py +++ /dev/null @@ -1,34 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import asyncio -import random -from azure.iot.device.aio import IoTHubModuleClient - - -async def main(): - # The client object is used to interact with your Azure IoT hub. - module_client = IoTHubModuleClient.create_from_edge_environment() - - # connect the client. - await module_client.connect() - - # update the reported properties - reported_properties = {"temperature": random.randint(320, 800) / 10} - print("Setting reported temperature to {}".format(reported_properties["temperature"])) - await module_client.patch_twin_reported_properties(reported_properties) - - # Finally, shut down the client - await module_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-hub-scenarios/README.md b/samples/async-hub-scenarios/README.md deleted file mode 100644 index 111275aac..000000000 --- a/samples/async-hub-scenarios/README.md +++ /dev/null @@ -1,94 +0,0 @@ -# Advanced IoT Hub Scenario Samples for the Azure IoT Hub Device SDK - -This directory contains samples showing how to use the various features of Azure IoT Hub Device SDK with the Azure IoT Hub. - -**These samples are written to run in Python 3.7+**, but can be made to work with Python 3.6 with a slight modification as noted in each sample: - -```python -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() -``` - -## Included Samples - -### IoTHub Samples - -In order to use these samples, you **must** set your Device Connection String in the environment variable `IOTHUB_DEVICE_CONNECTION_STRING`. 
- -* [send_message.py](send_message.py) - Send multiple telemetry messages in parallel from a device to the Azure IoT Hub. - * You can monitor the Azure IoT Hub for messages received by using the following Azure CLI command: - - ```bash - az iot hub monitor-events --hub-name --output table - ``` - - * [recurring_telemetry.py](recurring_telemetry.py) - Send telemetry message every two seconds from a device to the Azure IoT Hub. - * You can monitor the Azure IoT Hub for messages received by using the following Azure CLI command: - - ```bash - az iot hub monitor-events --hub-name --output table - ``` - -* [receive_message.py](receive_message.py) - Receive Cloud-to-Device (C2D) messages sent from the Azure IoT Hub to a device. - * In order to send a C2D message, use the following Azure CLI command: - - ```bash - az iot device c2d-message send --device-id --hub-name --data - ``` - -* [receive_direct_method.py](receive_direct_method.py) - Receive direct method requests on a device from the Azure IoT Hub and send responses back - * In order to invoke a direct method, use the following Azure CLI command: - - ```bash - az iot hub invoke-device-method --device-id --hub-name --method-name - ``` - -* [receive_twin_desired_properties_patch](receive_twin_desired_properties_patch.py) - Receive an update patch of changes made to the device twin's desired properties - * In order to send a update patch to a device twin's desired properties, use the following Azure CLI command: - - ```bash - az iot hub device-twin update --device-id --hub-name --set properties.desired.= - ``` - -* [update_twin_reported_properties](update_twin_reported_properties.py) - Send an update patch of changes to the device twin's reported properties - * You can see the changes reflected in your device twin by using the following Azure CLI command: - - ```bash - az iot hub device-twin show --device-id --hub-name - ``` - -* [upload_to_blob](upload_to_blob.py) - Upload file into the linked Azure storage account - 
* You must associate an Azure storage account to the IoT Hub instance. [Learn more](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-configure-file-upload) - -### DPS Samples - -#### Individual Enrollment - -In order to use these samples, you **must** have the following environment variables :- - -* PROVISIONING_HOST -* PROVISIONING_IDSCOPE -* PROVISIONING_REGISTRATION_ID - -There are 2 ways that your device can get registered to the provisioning service differing in authentication mechanisms. Depending on the mechanism used additional environment variables are needed for the samples:- - -* [provision_symmetric_key.py](provision_symmetric_key.py) - Provision a device to IoTHub by registering to the Device Provisioning Service using a symmetric key, then send telemetry messages to IoTHub. For this you must have the environment variable PROVISIONING_SYMMETRIC_KEY. -* [provision_symmetric_key_with_payload.py](provision_symmetric_key_with_payload.py) - Provision a device to IoTHub by registering to the Device Provisioning Service using a symmetric key while supplying a custom payload, then send telemetry messages to IoTHub. For this you must have the environment variable PROVISIONING_SYMMETRIC_KEY. -* [provision_x509.py](provision_x509.py) - Provision a device to IoTHub by registering to the Device Provisioning Service using a symmetric key, then send a telemetry message to IoTHub. For this you must have the environment variable X509_CERT_FILE, X509_KEY_FILE, PASS_PHRASE. - - -#### Group Enrollment - -In order to use these samples, you **must** have the following environment variables :- - -* PROVISIONING_HOST -* PROVISIONING_IDSCOPE - -* [provision_symmetric_key_group.py](provision_symmetric_key_group.py) - Provision multiple devices to IoTHub by registering them to the Device Provisioning Service using derived symmetric keys, then send telemetry to IoTHub from these devices. 
For this you must have knowledge of the group symmetric key and must have the environment variables PROVISIONING_DEVICE_ID_1, PROVISIONING_DEVICE_ID_2, PROVISIONING_DEVICE_ID_3. - * NOTE : Group symmetric key must NEVER be stored and all the device keys must be computationally derived prior to using this sample. - diff --git a/samples/async-hub-scenarios/get_twin.py b/samples/async-hub-scenarios/get_twin.py deleted file mode 100644 index cf521044b..000000000 --- a/samples/async-hub-scenarios/get_twin.py +++ /dev/null @@ -1,34 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -import asyncio -from azure.iot.device.aio import IoTHubDeviceClient - - -async def main(): - conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - - # connect the client. - await device_client.connect() - - # get the twin - twin = await device_client.get_twin() - print("Twin document:") - print("{}".format(twin)) - - # Finally, shut down the client - await device_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-hub-scenarios/provision_symmetric_key.py b/samples/async-hub-scenarios/provision_symmetric_key.py deleted file mode 100644 index 7b99f458b..000000000 --- a/samples/async-hub-scenarios/provision_symmetric_key.py +++ /dev/null @@ -1,67 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import asyncio -from azure.iot.device.aio import ProvisioningDeviceClient -import os -from azure.iot.device.aio import IoTHubDeviceClient -from azure.iot.device import Message -import uuid - - -messages_to_send = 10 -provisioning_host = os.getenv("PROVISIONING_HOST") -id_scope = os.getenv("PROVISIONING_IDSCOPE") -registration_id = os.getenv("PROVISIONING_REGISTRATION_ID") -symmetric_key = os.getenv("PROVISIONING_SYMMETRIC_KEY") - - -async def main(): - provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=provisioning_host, - registration_id=registration_id, - id_scope=id_scope, - symmetric_key=symmetric_key, - ) - - registration_result = await provisioning_device_client.register() - - print("The complete registration result is") - print(registration_result.registration_state) - - if registration_result.status == "assigned": - print("Will send telemetry from the provisioned device") - device_client = IoTHubDeviceClient.create_from_symmetric_key( - symmetric_key=symmetric_key, - hostname=registration_result.registration_state.assigned_hub, - device_id=registration_result.registration_state.device_id, - ) - # Connect the client. 
- await device_client.connect() - - async def send_test_message(i): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - await device_client.send_message(msg) - print("done sending message #" + str(i)) - - # send `messages_to_send` messages in parallel - await asyncio.gather(*[send_test_message(i) for i in range(1, messages_to_send + 1)]) - - # finally, disconnect - await device_client.disconnect() - else: - print("Can not send telemetry from the provisioned device") - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-hub-scenarios/provision_symmetric_key_group.py b/samples/async-hub-scenarios/provision_symmetric_key_group.py deleted file mode 100644 index 904625c1a..000000000 --- a/samples/async-hub-scenarios/provision_symmetric_key_group.py +++ /dev/null @@ -1,138 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -import asyncio -import base64 -import hmac -import hashlib -from azure.iot.device.aio import ProvisioningDeviceClient -from azure.iot.device.aio import IoTHubDeviceClient -from azure.iot.device import Message -import uuid - -messages_to_send = 5 - -provisioning_host = os.getenv("PROVISIONING_HOST") -id_scope = os.getenv("PROVISIONING_IDSCOPE") - -# These are the names of the devices that will eventually show up on the IoTHub -# Please make sure that there are no spaces in these device ids. 
-device_id_1 = os.getenv("PROVISIONING_DEVICE_ID_1") -device_id_2 = os.getenv("PROVISIONING_DEVICE_ID_2") -device_id_3 = os.getenv("PROVISIONING_DEVICE_ID_3") - -# For computation of device keys -device_ids_to_keys = {} - - -# NOTE : Only for illustration purposes. -# This is how a device key can be derived from the group symmetric key. -# This is just a helper function to show how it is done. -# Please don't directly store the group master key on the device. -# Follow the following method to compute the device key somewhere else. - - -def derive_device_key(device_id, group_symmetric_key): - """ - The unique device ID and the group master key should be encoded into "utf-8" - After this the encoded group master key must be used to compute an HMAC-SHA256 of the encoded registration ID. - Finally the result must be converted into Base64 format. - The device key is the "utf-8" decoding of the above result. - """ - message = device_id.encode("utf-8") - signing_key = base64.b64decode(group_symmetric_key.encode("utf-8")) - signed_hmac = hmac.HMAC(signing_key, message, hashlib.sha256) - device_key_encoded = base64.b64encode(signed_hmac.digest()) - return device_key_encoded.decode("utf-8") - - -# derived_device_key has been computed already using the helper function somewhere else -# AND NOT on this sample. Do not use the direct master key on this sample to compute device key. 
-derived_device_key_1 = "some_value_already_computed" -derived_device_key_2 = "some_value_already_computed" -derived_device_key_3 = "some_value_already_computed" - -device_ids_to_keys[device_id_1] = derived_device_key_1 -device_ids_to_keys[device_id_2] = derived_device_key_2 -device_ids_to_keys[device_id_3] = derived_device_key_3 - - -async def send_test_message(i, client): - print("sending message # {index} for client with id {id}".format(index=i, id=client.id)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - await client.send_message(msg) - print("done sending message # {index} for client with id {id}".format(index=i, id=client.id)) - - -async def main(): - async def register_device(registration_id): - provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=provisioning_host, - registration_id=registration_id, - id_scope=id_scope, - symmetric_key=device_ids_to_keys[registration_id], - ) - - return await provisioning_device_client.register() - - results = await asyncio.gather( - register_device(device_id_1), register_device(device_id_2), register_device(device_id_3) - ) - - clients_to_device_ids = {} - - for index in range(0, len(results)): - registration_result = results[index] - print("The complete state of registration result is") - print(registration_result.registration_state) - - if registration_result.status == "assigned": - device_id = registration_result.registration_state.device_id - - print( - "Will send telemetry from the provisioned device with id {id}".format(id=device_id) - ) - device_client = IoTHubDeviceClient.create_from_symmetric_key( - symmetric_key=device_ids_to_keys[device_id], - hostname=registration_result.registration_state.assigned_hub, - device_id=registration_result.registration_state.device_id, - ) - # Assign the Id just for print statements - device_client.id = device_id - - clients_to_device_ids[device_id] = device_client - - else: - print("Can not send 
telemetry from the provisioned device") - - # connect all the clients - await asyncio.gather(*[client.connect() for client in clients_to_device_ids.values()]) - - # send `messages_to_send` messages in parallel. - await asyncio.gather( - *[ - send_test_message(i, client) - for i, client in [ - (i, client) - for i in range(1, messages_to_send + 1) - for client in clients_to_device_ids.values() - ] - ] - ) - - # disconnect all the clients - await asyncio.gather(*[client.disconnect() for client in clients_to_device_ids.values()]) - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-hub-scenarios/provision_symmetric_key_with_payload.py b/samples/async-hub-scenarios/provision_symmetric_key_with_payload.py deleted file mode 100644 index a56a45343..000000000 --- a/samples/async-hub-scenarios/provision_symmetric_key_with_payload.py +++ /dev/null @@ -1,77 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import os -import asyncio -from azure.iot.device.aio import ProvisioningDeviceClient -from azure.iot.device.aio import IoTHubDeviceClient -from azure.iot.device import Message -import uuid - -provisioning_host = os.getenv("PROVISIONING_HOST") -id_scope = os.getenv("PROVISIONING_IDSCOPE") -registration_id = os.getenv("PROVISIONING_REGISTRATION_ID_PAYLOAD") -symmetric_key = os.getenv("PROVISIONING_SYMMETRIC_KEY_PAYLOAD") - -messages_to_send = 10 - - -class Fruit(object): - def __init__(self, first_name, last_name, dict_of_stuff): - self.first_name = first_name - self.last_name = last_name - self.props = dict_of_stuff - - -async def main(): - provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=provisioning_host, - registration_id=registration_id, - id_scope=id_scope, - symmetric_key=symmetric_key, - ) - - properties = {"Type": "Apple", "Sweet": "True"} - fruit_a = Fruit("McIntosh", "Red", properties) - provisioning_device_client.provisioning_payload = fruit_a - registration_result = await provisioning_device_client.register() - - print("The complete registration result is") - print(registration_result.registration_state) - - if registration_result.status == "assigned": - print("Will send telemetry from the provisioned device") - device_client = IoTHubDeviceClient.create_from_symmetric_key( - symmetric_key=symmetric_key, - hostname=registration_result.registration_state.assigned_hub, - device_id=registration_result.registration_state.device_id, - ) - # Connect the client. 
- await device_client.connect() - - async def send_test_message(i): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - await device_client.send_message(msg) - print("done sending message #" + str(i)) - - # send `messages_to_send` messages in parallel - await asyncio.gather(*[send_test_message(i) for i in range(1, messages_to_send + 1)]) - - # finally, disconnect - await device_client.disconnect() - else: - print("Can not send telemetry from the provisioned device") - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-hub-scenarios/provision_x509.py b/samples/async-hub-scenarios/provision_x509.py deleted file mode 100644 index 5312c2392..000000000 --- a/samples/async-hub-scenarios/provision_x509.py +++ /dev/null @@ -1,73 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import os -import asyncio -from azure.iot.device import X509 -from azure.iot.device.aio import ProvisioningDeviceClient -from azure.iot.device.aio import IoTHubDeviceClient -from azure.iot.device import Message -import uuid - - -provisioning_host = os.getenv("PROVISIONING_HOST") -id_scope = os.getenv("PROVISIONING_IDSCOPE") -registration_id = os.getenv("DPS_X509_REGISTRATION_ID") -messages_to_send = 10 - - -async def main(): - x509 = X509( - cert_file=os.getenv("X509_CERT_FILE"), - key_file=os.getenv("X509_KEY_FILE"), - pass_phrase=os.getenv("X509_PASS_PHRASE"), - ) - - provisioning_device_client = ProvisioningDeviceClient.create_from_x509_certificate( - provisioning_host=provisioning_host, - registration_id=registration_id, - id_scope=id_scope, - x509=x509, - ) - - registration_result = await provisioning_device_client.register() - - print("The complete registration result is") - print(registration_result.registration_state) - - if registration_result.status == "assigned": - print("Will send telemetry from the provisioned device") - device_client = IoTHubDeviceClient.create_from_x509_certificate( - x509=x509, - hostname=registration_result.registration_state.assigned_hub, - device_id=registration_result.registration_state.device_id, - ) - - # Connect the client. 
- await device_client.connect() - - async def send_test_message(i): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - await device_client.send_message(msg) - print("done sending message #" + str(i)) - - # send `messages_to_send` messages in parallel - await asyncio.gather(*[send_test_message(i) for i in range(1, messages_to_send + 1)]) - - # finally, disconnect - await device_client.disconnect() - else: - print("Can not send telemetry from the provisioned device") - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-hub-scenarios/receive_direct_method.py b/samples/async-hub-scenarios/receive_direct_method.py deleted file mode 100644 index c16f21237..000000000 --- a/samples/async-hub-scenarios/receive_direct_method.py +++ /dev/null @@ -1,71 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -import asyncio -from azure.iot.device.aio import IoTHubDeviceClient -from azure.iot.device import MethodResponse - - -async def main(): - # The connection string for a device should never be stored in code. For the sake of simplicity we're using an environment variable here. - conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - - # The client object is used to interact with your Azure IoT hub. - device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - - # connect the client. 
- await device_client.connect() - - # Define behavior for handling methods - async def method_request_handler(method_request): - # Determine how to respond to the method request based on the method name - if method_request.name == "method1": - payload = {"result": True, "data": "some data"} # set response payload - status = 200 # set return status code - print("executed method1") - elif method_request.name == "method2": - payload = {"result": True, "data": 1234} # set response payload - status = 200 # set return status code - print("executed method2") - else: - payload = {"result": False, "data": "unknown method"} # set response payload - status = 400 # set return status code - print("executed unknown method: " + method_request.name) - - # Send the response - method_response = MethodResponse.create_from_method_request(method_request, status, payload) - await device_client.send_method_response(method_response) - - # Set the method request handler on the client - device_client.on_method_request_received = method_request_handler - - # Define behavior for halting the application - def stdin_listener(): - while True: - selection = input("Press Q to quit\n") - if selection == "Q" or selection == "q": - print("Quitting...") - break - - # Run the stdin listener in the event loop - loop = asyncio.get_running_loop() - user_finished = loop.run_in_executor(None, stdin_listener) - - # Wait for user to indicate they are done listening for method calls - await user_finished - - # Finally, shut down the client - await device_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-hub-scenarios/receive_message.py b/samples/async-hub-scenarios/receive_message.py deleted file mode 100644 index 5c966acd1..000000000 --- a/samples/async-hub-scenarios/receive_message.py +++ 
/dev/null @@ -1,59 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -import asyncio -from azure.iot.device.aio import IoTHubDeviceClient - - -async def main(): - # The connection string for a device should never be stored in code. For the sake of simplicity we're using an environment variable here. - conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - # The client object is used to interact with your Azure IoT hub. - device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - - # connect the client. - await device_client.connect() - - # define behavior for receiving a message - # NOTE: this could be a function or a coroutine - def message_received_handler(message): - print("the data in the message received was ") - print(message.data) - print("custom properties are") - print(message.custom_properties) - print("content Type: {0}".format(message.content_type)) - print("") - - # set the message received handler on the client - device_client.on_message_received = message_received_handler - - # define behavior for halting the application - def stdin_listener(): - while True: - selection = input("Press Q to quit\n") - if selection == "Q" or selection == "q": - print("Quitting...") - break - - # Run the stdin listener in the event loop - loop = asyncio.get_running_loop() - user_finished = loop.run_in_executor(None, stdin_listener) - - # Wait for user to indicate they are done listening for messages - await user_finished - - # Finally, shut down the client - await device_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # 
loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-hub-scenarios/receive_message_x509.py b/samples/async-hub-scenarios/receive_message_x509.py deleted file mode 100644 index bd01e92d6..000000000 --- a/samples/async-hub-scenarios/receive_message_x509.py +++ /dev/null @@ -1,67 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -import asyncio -from azure.iot.device.aio import IoTHubDeviceClient -from azure.iot.device import X509 - - -async def main(): - hostname = os.getenv("HOSTNAME") - # The device that has been created on the portal using X509 CA signing or Self signing capabilities - device_id = os.getenv("DEVICE_ID") - - x509 = X509( - cert_file=os.getenv("X509_CERT_FILE"), - key_file=os.getenv("X509_KEY_FILE"), - pass_phrase=os.getenv("X509_PASS_PHRASE"), - ) - - # The client object is used to interact with your Azure IoT hub. 
- device_client = IoTHubDeviceClient.create_from_x509_certificate( - hostname=hostname, device_id=device_id, x509=x509 - ) - - await device_client.connect() - - # Define behavior for receiving a message - # NOTE: this could be a function or a coroutine - def message_received_handler(message): - print("the data in the message received was ") - print(message.data) - print("custom properties are") - print(message.custom_properties) - - # Set the message received handler on the client - device_client.on_message_received = message_received_handler - - # Define behavior for halting the application - def stdin_listener(): - while True: - selection = input("Press Q to quit\n") - if selection == "Q" or selection == "q": - print("Quitting...") - break - - # Run the stdin listener in the event loop - loop = asyncio.get_running_loop() - user_finished = loop.run_in_executor(None, stdin_listener) - - # Wait for user to indicate they are done listening for messages - await user_finished - - # Finally, shut down the client - await device_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-hub-scenarios/receive_twin_desired_properties_patch.py b/samples/async-hub-scenarios/receive_twin_desired_properties_patch.py deleted file mode 100644 index e3c7e05b9..000000000 --- a/samples/async-hub-scenarios/receive_twin_desired_properties_patch.py +++ /dev/null @@ -1,54 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import os -import asyncio -from azure.iot.device.aio import IoTHubDeviceClient - - -async def main(): - # The connection string for a device should never be stored in code. For the sake of simplicity we're using an environment variable here. - conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - # The client object is used to interact with your Azure IoT hub. - device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - - # connect the client. - await device_client.connect() - - # define behavior for receiving a twin patch - # NOTE: this could be a function or a coroutine - def twin_patch_handler(patch): - print("the data in the desired properties patch was: {}".format(patch)) - - # set the twin patch handler on the client - device_client.on_twin_desired_properties_patch_received = twin_patch_handler - - # define behavior for halting the application - def stdin_listener(): - while True: - selection = input("Press Q to quit\n") - if selection == "Q" or selection == "q": - print("Quitting...") - break - - # Run the stdin listener in the event loop - loop = asyncio.get_running_loop() - user_finished = loop.run_in_executor(None, stdin_listener) - - # Wait for user to indicate they are done listening for messages - await user_finished - - # Finally, shut down the client - await device_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-hub-scenarios/recurring_telemetry.py b/samples/async-hub-scenarios/recurring_telemetry.py deleted file mode 100644 index ab877ad32..000000000 --- a/samples/async-hub-scenarios/recurring_telemetry.py +++ /dev/null @@ -1,55 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) 
Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import os -import asyncio -import time -import uuid -from azure.iot.device.aio import IoTHubDeviceClient -from azure.iot.device import Message - - -async def send_recurring_telemetry(device_client): - # Connect the client. - await device_client.connect() - - # Send recurring telemetry - i = 0 - while True: - i += 1 - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.correlation_id = "correlation-1234" - msg.custom_properties["tornado-warning"] = "yes" - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - print("sending message #" + str(i)) - await device_client.send_message(msg) - time.sleep(2) - - -def main(): - # The connection string for a device should never be stored in code. For the sake of simplicity we're using an environment variable here. - conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - # The client object is used to interact with your Azure IoT hub. 
- device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - - print("IoTHub Device Client Recurring Telemetry Sample") - print("Press Ctrl+C to exit") - loop = asyncio.get_event_loop() - try: - loop.run_until_complete(send_recurring_telemetry(device_client)) - except KeyboardInterrupt: - print("User initiated exit") - except Exception: - print("Unexpected exception!") - raise - finally: - loop.run_until_complete(device_client.shutdown()) - loop.close() - - -if __name__ == "__main__": - main() diff --git a/samples/async-hub-scenarios/send_message.py b/samples/async-hub-scenarios/send_message.py deleted file mode 100644 index 236fe2b01..000000000 --- a/samples/async-hub-scenarios/send_message.py +++ /dev/null @@ -1,50 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -import asyncio -import uuid -from azure.iot.device.aio import IoTHubDeviceClient -from azure.iot.device import Message - -messages_to_send = 10 - - -async def main(): - # The connection string for a device should never be stored in code. For the sake of simplicity we're using an environment variable here. - conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - - # The client object is used to interact with your Azure IoT hub. - device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - - # Connect the client. 
- await device_client.connect() - - async def send_test_message(i): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.correlation_id = "correlation-1234" - msg.custom_properties["tornado-warning"] = "yes" - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - await device_client.send_message(msg) - print("done sending message #" + str(i)) - - # send `messages_to_send` messages in parallel - await asyncio.gather(*[send_test_message(i) for i in range(1, messages_to_send + 1)]) - - # Finally, shut down the client - await device_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-hub-scenarios/send_message_over_websockets.py b/samples/async-hub-scenarios/send_message_over_websockets.py deleted file mode 100644 index b563ae672..000000000 --- a/samples/async-hub-scenarios/send_message_over_websockets.py +++ /dev/null @@ -1,36 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -import asyncio -from azure.iot.device.aio import IoTHubDeviceClient - - -async def main(): - # Fetch the connection string from an environment variable - conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - - # Create instance of the device client using the connection string - device_client = IoTHubDeviceClient.create_from_connection_string(conn_str, websockets=True) - - # We do not need to call device_client.connect(), since it will be connected when we send a message. 
- - # Send a single message - print("Sending message...") - await device_client.send_message("This is a message that is being sent") - print("Message successfully sent!") - - # Finally, shut down the client - await device_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-hub-scenarios/send_message_via_module_x509.py b/samples/async-hub-scenarios/send_message_via_module_x509.py deleted file mode 100644 index 6d6e1ee8c..000000000 --- a/samples/async-hub-scenarios/send_message_via_module_x509.py +++ /dev/null @@ -1,62 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -import uuid -from azure.iot.device.aio import IoTHubModuleClient -from azure.iot.device import Message, X509 -import asyncio - - -messages_to_send = 10 - - -async def main(): - hostname = os.getenv("HOSTNAME") - - # The device having a certain module that has been created on the portal - # using X509 CA signing or Self signing capabilities - - device_id = os.getenv("DEVICE_ID") - module_id = os.getenv("MODULE_ID") - - x509 = X509( - cert_file=os.getenv("X509_CERT_FILE"), - key_file=os.getenv("X509_KEY_FILE"), - pass_phrase=os.getenv("X509_PASS_PHRASE"), - ) - - module_client = IoTHubModuleClient.create_from_x509_certificate( - hostname=hostname, x509=x509, device_id=device_id, module_id=module_id - ) - - # Connect the client. 
- await module_client.connect() - - async def send_test_message(i): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.correlation_id = "correlation-1234" - msg.custom_properties["tornado-warning"] = "yes" - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - await module_client.send_message(msg) - print("done sending message #" + str(i)) - - await asyncio.gather(*[send_test_message(i) for i in range(1, messages_to_send + 1)]) - - # Finally, shut down the client - await module_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-hub-scenarios/send_message_via_proxy.py b/samples/async-hub-scenarios/send_message_via_proxy.py deleted file mode 100644 index ab25b00e9..000000000 --- a/samples/async-hub-scenarios/send_message_via_proxy.py +++ /dev/null @@ -1,56 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -import asyncio -import uuid -from azure.iot.device.aio import IoTHubDeviceClient -from azure.iot.device import Message, ProxyOptions - -messages_to_send = 10 - - -async def main(): - # The connection string for a device should never be stored in code. For the sake of simplicity we're using an environment variable here. - conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - - proxy_opts = ProxyOptions( - proxy_type="HTTP", proxy_addr="127.0.0.1", proxy_port=8888 # localhost - ) - - # The client object is used to interact with your Azure IoT hub. 
- device_client = IoTHubDeviceClient.create_from_connection_string( - conn_str, websockets=True, proxy_options=proxy_opts - ) - - # Connect the client. - await device_client.connect() - - async def send_test_message(i): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.correlation_id = "correlation-1234" - msg.custom_properties["tornado-warning"] = "yes" - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - await device_client.send_message(msg) - print("done sending message #" + str(i)) - - # send `messages_to_send` messages in parallel - await asyncio.gather(*[send_test_message(i) for i in range(1, messages_to_send + 1)]) - - # Finally, shut down the client - await device_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-hub-scenarios/send_message_x509.py b/samples/async-hub-scenarios/send_message_x509.py deleted file mode 100644 index cbeba73b7..000000000 --- a/samples/async-hub-scenarios/send_message_x509.py +++ /dev/null @@ -1,60 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import os -import uuid -from azure.iot.device.aio import IoTHubDeviceClient -from azure.iot.device import Message, X509 -import asyncio - - -messages_to_send = 10 - - -async def main(): - hostname = os.getenv("HOSTNAME") - # The device that has been created on the portal using X509 CA signing or Self signing capabilities - device_id = os.getenv("DEVICE_ID") - - x509 = X509( - cert_file=os.getenv("X509_CERT_FILE"), - key_file=os.getenv("X509_KEY_FILE"), - pass_phrase=os.getenv("X509_PASS_PHRASE"), - ) - - # The client object is used to interact with your Azure IoT hub. - device_client = IoTHubDeviceClient.create_from_x509_certificate( - hostname=hostname, device_id=device_id, x509=x509 - ) - - # Connect the client. - await device_client.connect() - - async def send_test_message(i): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.correlation_id = "correlation-1234" - msg.custom_properties["tornado-warning"] = "yes" - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - await device_client.send_message(msg) - print("done sending message #" + str(i)) - - # send `messages_to_send` messages in parallel - await asyncio.gather(*[send_test_message(i) for i in range(1, messages_to_send + 1)]) - - # Finally, shut down the client - await device_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-hub-scenarios/update_twin_reported_properties.py b/samples/async-hub-scenarios/update_twin_reported_properties.py deleted file mode 100644 index 507300e43..000000000 --- a/samples/async-hub-scenarios/update_twin_reported_properties.py +++ /dev/null @@ -1,35 +0,0 @@ -# 
------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -import asyncio -import random -from azure.iot.device.aio import IoTHubDeviceClient - - -async def main(): - conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - - # connect the client. - await device_client.connect() - - # update the reported properties - reported_properties = {"temperature": random.randint(320, 800) / 10} - print("Setting reported temperature to {}".format(reported_properties["temperature"])) - await device_client.patch_twin_reported_properties(reported_properties) - - # Finally, shut down the client - await device_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/async-hub-scenarios/upload_to_blob.py b/samples/async-hub-scenarios/upload_to_blob.py deleted file mode 100644 index 0fb6ff8da..000000000 --- a/samples/async-hub-scenarios/upload_to_blob.py +++ /dev/null @@ -1,135 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import os -import uuid -import asyncio -from azure.iot.device.aio import IoTHubDeviceClient -import pprint -from azure.storage.blob import BlobClient -from azure.core.exceptions import ResourceExistsError -import logging - -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - -""" -Welcome to the Upload to Blob sample for the Azure IoT Device Library for Python. To use this sample you must have azure.storage.blob installed in your python environment. -To do this, you can run: - - $ pip install azure.storage.blob - -This sample covers using the following Device Client APIs: - - get_storage_info_for_blob - - used to get relevant information from IoT Hub about a linked Storage Account, including - a hostname, a container name, a blob name, and a sas token. Additionally it returns a correlation_id - which is used in the notify_blob_upload_status, since the correlation_id is IoT Hub's way of marking - which blob you are working on. - notify_blob_upload_status - - used to notify IoT Hub of the status of your blob storage operation. This uses the correlation_id obtained - by the get_storage_info_for_blob task, and will tell IoT Hub to notify any service that might be listening for a notification on the - status of the file upload task. - -You can learn more about File Upload with IoT Hub here: - -https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-file-upload - -""" -IOTHUB_DEVICE_CONNECTION_STRING = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - - -async def upload_via_storage_blob(blob_info): - """Helper function written to perform Storage Blob V12 Upload Tasks - - Arguments: - blob_info - an object containing the information needed to generate a sas_url for creating a blob client - - Returns: - status of blob upload operation, in the storage provided structure. 
- """ - - print("Azure Blob storage v12 - Python quickstart sample") - sas_url = "https://{}/{}/{}{}".format( - blob_info["hostName"], - blob_info["containerName"], - blob_info["blobName"], - blob_info["sasToken"], - ) - blob_client = BlobClient.from_blob_url(sas_url) - - # The following file code can be replaced with simply a sample file in a directory. - - # Create a file in local Documents directory to upload and download - local_file_name = "data/quickstart" + str(uuid.uuid4()) + ".txt" - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), local_file_name) - - # Write text to the file - if not os.path.exists(os.path.dirname(filename)): - os.makedirs(os.path.dirname(filename)) - file = open(filename, "w") - file.write("Hello, World!") - file.close() - - # Perform the actual upload for the data. - print("\nUploading to Azure Storage as blob:\n\t" + local_file_name) - # # Upload the created file - with open(filename, "rb") as data: - result = blob_client.upload_blob(data) - - return result - - -async def main(): - conn_str = IOTHUB_DEVICE_CONNECTION_STRING - device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - - # Connect the client. - await device_client.connect() - - # get the Storage SAS information from IoT Hub. - blob_name = "fakeBlobName12" - storage_info = await device_client.get_storage_info_for_blob(blob_name) - result = {"status_code": -1, "status_description": "N/A"} - - # Using the Storage Blob V12 API, perform the blob upload. 
- try: - upload_result = await upload_via_storage_blob(storage_info) - if hasattr(upload_result, "error_code"): - result = { - "status_code": upload_result.error_code, - "status_description": "Storage Blob Upload Error", - } - else: - result = {"status_code": 200, "status_description": ""} - except ResourceExistsError as ex: - if ex.status_code: - result = {"status_code": ex.status_code, "status_description": ex.reason} - else: - print("Failed with Exception: {}", ex) - result = {"status_code": 400, "status_description": ex.message} - - pp = pprint.PrettyPrinter(indent=4) - pp.pprint(result) - - if result["status_code"] == 200: - await device_client.notify_blob_upload_status( - storage_info["correlationId"], True, result["status_code"], result["status_description"] - ) - else: - await device_client.notify_blob_upload_status( - storage_info["correlationId"], - False, - result["status_code"], - result["status_description"], - ) - - # Finally, shut down the client - await device_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/samples/async-hub-scenarios/upload_to_blob_x509.py b/samples/async-hub-scenarios/upload_to_blob_x509.py deleted file mode 100644 index 23aa05b90..000000000 --- a/samples/async-hub-scenarios/upload_to_blob_x509.py +++ /dev/null @@ -1,149 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import os -import uuid -import asyncio -from azure.iot.device.aio import IoTHubDeviceClient -from azure.iot.device import X509 -import pprint -from azure.storage.blob import BlobClient -from azure.core.exceptions import ResourceExistsError -import logging - -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - -""" -Welcome to the Upload to Blob sample for the Azure IoT Device Library for Python. To use this sample you must have azure.storage.blob installed in your python environment. -To do this, you can run: - - $ pip install azure.storage.blob - -This sample covers using the following Device Client APIs: - - get_storage_info_for_blob - - used to get relevant information from IoT Hub about a linked Storage Account, including - a hostname, a container name, a blob name, and a sas token. Additionally it returns a correlation_id - which is used in the notify_blob_upload_status, since the correlation_id is IoT Hub's way of marking - which blob you are working on. - notify_blob_upload_status - - used to notify IoT Hub of the status of your blob storage operation. This uses the correlation_id obtained - by the get_storage_info_for_blob task, and will tell IoT Hub to notify any service that might be listening for a notification on the - status of the file upload task. 
- -You can learn more about File Upload with IoT Hub here: - -https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-file-upload - -""" - -IOTHUB_HOSTNAME = os.getenv("IOTHUB_HOSTNAME") -IOTHUB_DEVICE_ID = os.getenv("IOTHUB_DEVICE_ID") - -X509_CERT_FILE = os.getenv("X509_CERT_FILE") -X509_KEY_FILE = os.getenv("X509_KEY_FILE") -X509_PASS_PHRASE = os.getenv("X509_PASS_PHRASE") - -# Host is in format ".azure-devices.net" - - -async def upload_via_storage_blob(blob_info): - """Helper function written to perform Storage Blob V12 Upload Tasks - - Arguments: - blob_info - an object containing the information needed to generate a sas_url for creating a blob client - - Returns: - status of blob upload operation, in the storage provided structure. - """ - - print("Azure Blob storage v12 - Python quickstart sample") - sas_url = "https://{}/{}/{}{}".format( - blob_info["hostName"], - blob_info["containerName"], - blob_info["blobName"], - blob_info["sasToken"], - ) - blob_client = BlobClient.from_blob_url(sas_url) - - # The following file code can be replaced with simply a sample file in a directory. - - # Create a file in local Documents directory to upload and download - local_file_name = "data/quickstart" + str(uuid.uuid4()) + ".txt" - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), local_file_name) - - # Write text to the file - if not os.path.exists(os.path.dirname(filename)): - os.makedirs(os.path.dirname(filename)) - file = open(filename, "w") - file.write("Hello, World!") - file.close() - - # Perform the actual upload for the data. - print("\nUploading to Azure Storage as blob:\n\t" + local_file_name) - # # Upload the created file - with open(filename, "rb") as data: - result = blob_client.upload_blob(data) - - return result - - -async def main(): - hostname = IOTHUB_HOSTNAME - device_id = IOTHUB_DEVICE_ID - x509 = X509(cert_file=X509_CERT_FILE, key_file=X509_KEY_FILE, pass_phrase=X509_PASS_PHRASE) - - # Create the Device Client. 
- device_client = IoTHubDeviceClient.create_from_x509_certificate( - hostname=hostname, device_id=device_id, x509=x509 - ) - - # Connect the client. - await device_client.connect() - - # get the Storage SAS information from IoT Hub. - blob_name = "fakeBlobName12" - storage_info = await device_client.get_storage_info_for_blob(blob_name) - result = {"status_code": -1, "status_description": "N/A"} - - # Using the Storage Blob V12 API, perform the blob upload. - try: - upload_result = await upload_via_storage_blob(storage_info) - if hasattr(upload_result, "error_code"): - result = { - "status_code": upload_result.error_code, - "status_description": "Storage Blob Upload Error", - } - else: - result = {"status_code": 200, "status_description": ""} - except ResourceExistsError as ex: - if ex.status_code: - result = {"status_code": ex.status_code, "status_description": ex.reason} - else: - print("Failed with Exception: {}", ex) - result = {"status_code": 400, "status_description": ex.message} - - pp = pprint.PrettyPrinter(indent=4) - pp.pprint(result) - if result["status_code"] == 200: - await device_client.notify_blob_upload_status( - storage_info["correlationId"], True, result["status_code"], result["status_description"] - ) - else: - await device_client.notify_blob_upload_status( - storage_info["correlationId"], - False, - result["status_code"], - result["status_description"], - ) - - # Finally, shut down the client - await device_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/samples/async-hub-scenarios/use_custom_sastoken.py b/samples/async-hub-scenarios/use_custom_sastoken.py deleted file mode 100644 index 656d7f5f1..000000000 --- a/samples/async-hub-scenarios/use_custom_sastoken.py +++ /dev/null @@ -1,75 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import asyncio -from azure.iot.device.aio import IoTHubDeviceClient - - -# NOTE: This code needs to be completed in order to work. -# Fill out the get_new_sastoken() method to return a NEW custom sastoken from your solution. -# It must return a unique value each time it is called. -def get_new_sastoken(): - pass - - -async def main(): - - # Get a sastoken you generated - sastoken = get_new_sastoken() - # The client object is used to interact with your Azure IoT hub. - device_client = IoTHubDeviceClient.create_from_sastoken(sastoken) - - # connect the client. - await device_client.connect() - - # define behavior for receiving a message - # NOTE: this could be a function or a coroutine - def message_received_handler(message): - print("the data in the message received was ") - print(message.data) - print("custom properties are") - print(message.custom_properties) - print("content Type: {0}".format(message.content_type)) - print("") - - # define behavior for updating sastoken - async def sastoken_update_handler(): - print("Updating SAS Token...") - sastoken = get_new_sastoken() - await device_client.update_sastoken(sastoken) - print("SAS Token updated") - - # set the message received handler on the client - device_client.on_message_received = message_received_handler - # set the sastoken update handler on the client - device_client.on_new_sastoken_required = sastoken_update_handler - - # define behavior for halting the application - def stdin_listener(): - while True: - selection = input("Press Q to quit\n") - if selection == "Q" or selection == "q": - print("Quitting...") - break - - # Run the stdin listener in the event loop - loop = asyncio.get_running_loop() - user_finished = loop.run_in_executor(None, stdin_listener) - - # Wait for user to indicate they are done listening for messages - await user_finished - - # Finally, 
shut down the client - await device_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/iothub_multi_feature.py b/samples/iothub_multi_feature.py new file mode 100644 index 000000000..d4f00c8cf --- /dev/null +++ b/samples/iothub_multi_feature.py @@ -0,0 +1,113 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""This sample demonstrates a more complex scenario for use of the IoTHubSession. +This application both sends and receives data, and can be controlled via direct methods. +If the connection drops, it will try to establish one again until the user exits. 
+""" + +import asyncio +import os +from azure.iot.device import ( + IoTHubSession, + DirectMethodResponse, + MQTTConnectionDroppedError, + MQTTConnectionFailedError, +) + + +CONNECTION_STRING = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") +TOTAL_MESSAGES_SENT = 0 +TOTAL_MESSAGES_RECEIVED = 0 + + +def do_foo(): + print("FOO!") + return True + + +def do_bar(): + print("BAR!") + return True + + +async def send_telemetry(session): + global TOTAL_MESSAGES_SENT + while True: + TOTAL_MESSAGES_SENT += 1 + print("Sending Message #{}.".format(TOTAL_MESSAGES_SENT)) + await session.send_message("Message #{}".format(TOTAL_MESSAGES_SENT)) + print("Send complete") + await asyncio.sleep(5) + + +async def receive_c2d_messages(session): + global TOTAL_MESSAGES_RECEIVED + async with session.messages() as messages: + print("Waiting to receive messages...") + async for message in messages: + TOTAL_MESSAGES_RECEIVED += 1 + print("Message received with payload: {}".format(message.payload)) + + +async def receive_direct_method_requests(session): + async with session.direct_method_requests() as method_requests: + async for method_request in method_requests: + if method_request.name == "foo": + print("Direct Method request received for 'foo'. Invoking.") + result = do_foo() + payload = {"result": result} + status = 200 + print("'foo' was completed with result: {}".format(result)) + elif method_request.name == "bar": + print("Direct Method request received for 'bar'. 
Invoking.") + result = do_bar() + payload = {"result": result} + status = 204 + print("'bar' was completed with result: {}".format(result)) + else: + payload = {} + status = 400 + print("Unknown Direct Method request received: {}".format(method_request.name)) + method_response = DirectMethodResponse.create_from_method_request( + method_request, status, payload + ) + await session.send_direct_method_response(method_response) + + +async def main(): + print("Starting multi-feature sample") + print("Press Ctrl-C to exit") + while True: + try: + print("Connecting to IoT Hub...") + async with IoTHubSession.from_connection_string(CONNECTION_STRING) as session: + print("Connected to IoT Hub") + await asyncio.gather( + send_telemetry(session), + receive_c2d_messages(session), + receive_direct_method_requests(session), + ) + + except MQTTConnectionDroppedError: + # Connection has been lost. Reconnect on next pass of loop. + print("Dropped connection. Reconnecting in 1 second") + await asyncio.sleep(1) + except MQTTConnectionFailedError: + # Connection failed to be established. Retry on next pass of loop. + print("Could not connect. Retrying in 10 seconds") + await asyncio.sleep(10) + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + # Exit application because user indicated they wish to exit. + # This will have cancelled `main()` implicitly. + print("User initiated exit. Exiting.") + finally: + print("Sent {} messages in total.".format(TOTAL_MESSAGES_SENT)) + print("Received {} messages in total.".format(TOTAL_MESSAGES_RECEIVED)) diff --git a/samples/iothub_simple_c2d.py b/samples/iothub_simple_c2d.py new file mode 100644 index 000000000..f4055e83c --- /dev/null +++ b/samples/iothub_simple_c2d.py @@ -0,0 +1,46 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""This sample demonstrates a simple cloud to device receive using an IoTHubSession.""" + +import asyncio +import os +from azure.iot.device import IoTHubSession, MQTTConnectionDroppedError, MQTTConnectionFailedError + +CONNECTION_STRING = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") +TOTAL_MESSAGES_RECEIVED = 0 + + +async def main(): + global TOTAL_MESSAGES_RECEIVED + print("Starting C2D sample") + print("Press Ctrl-C to exit") + try: + print("Connecting to IoT Hub...") + async with IoTHubSession.from_connection_string(CONNECTION_STRING) as session: + print("Connected to IoT Hub") + async with session.messages() as messages: + print("Waiting to receive messages...") + async for message in messages: + TOTAL_MESSAGES_RECEIVED += 1 + print("Message received with payload: {}".format(message.payload)) + + except MQTTConnectionDroppedError: + # Connection has been lost. + print("Dropped connection. Exiting") + except MQTTConnectionFailedError: + # Connection failed to be established. + print("Could not connect. Exiting") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + # Exit application because user indicated they wish to exit. + # This will have cancelled `main()` implicitly. + print("User initiated exit. Exiting") + finally: + print("Received {} messages in total".format(TOTAL_MESSAGES_RECEIVED)) diff --git a/samples/iothub_simple_c2d_with_reconnect.py b/samples/iothub_simple_c2d_with_reconnect.py new file mode 100644 index 000000000..fce67cfae --- /dev/null +++ b/samples/iothub_simple_c2d_with_reconnect.py @@ -0,0 +1,51 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +"""This sample demonstrates a simple cloud to device receive using an IoTHubSession. +If the connection drops, it will try to establish one again until the user exits. +""" + +import asyncio +import os +from azure.iot.device import IoTHubSession, MQTTConnectionDroppedError, MQTTConnectionFailedError + +CONNECTION_STRING = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") +TOTAL_MESSAGES_RECEIVED = 0 + + +async def main(): + global TOTAL_MESSAGES_RECEIVED + print("Starting C2D sample") + print("Press Ctrl-C to exit") + while True: + try: + print("Connecting to IoT Hub...") + async with IoTHubSession.from_connection_string(CONNECTION_STRING) as session: + print("Connected to IoT Hub") + async with session.messages() as messages: + print("Waiting to receive messages...") + async for message in messages: + TOTAL_MESSAGES_RECEIVED += 1 + print("Message received with payload: {}".format(message.payload)) + + except MQTTConnectionDroppedError: + # Connection has been lost. Reconnect on next pass of loop. + print("Dropped connection. Reconnecting in 1 second") + await asyncio.sleep(1) + except MQTTConnectionFailedError: + # Connection failed to be established. Retry on next pass of loop. + print("Could not connect. Retrying in 10 seconds") + await asyncio.sleep(10) + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + # Exit application because user indicated they wish to exit. + # This will have cancelled `main()` implicitly. + print("User initiated exit. Exiting") + finally: + print("Received {} messages in total".format(TOTAL_MESSAGES_RECEIVED)) diff --git a/samples/iothub_simple_telemetry.py b/samples/iothub_simple_telemetry.py new file mode 100644 index 000000000..2f675f20c --- /dev/null +++ b/samples/iothub_simple_telemetry.py @@ -0,0 +1,53 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. 
All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +""" +This sample demonstrates a simple recurring telemetry using an IoTHubSession + +It's set to be used in the following MS Learn Tutorial: +https://learn.microsoft.com/en-us/azure/iot-develop/quickstart-send-telemetry-iot-hub?pivots=programming-language-python +""" + +import asyncio +import os +from azure.iot.device import IoTHubSession, MQTTConnectionDroppedError, MQTTConnectionFailedError + +CONNECTION_STRING = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") +TOTAL_MESSAGES_SENT = 0 + + +async def main(): + global TOTAL_MESSAGES_SENT + print("Starting telemetry sample") + print("Press Ctrl-C to exit") + try: + global TOTAL_MESSAGES_SENT + print("Connecting to IoT Hub...") + async with IoTHubSession.from_connection_string(CONNECTION_STRING) as session: + print("Connected to IoT Hub") + while True: + TOTAL_MESSAGES_SENT += 1 + print("Sending Message #{}...".format(TOTAL_MESSAGES_SENT)) + await session.send_message("Message #{}".format(TOTAL_MESSAGES_SENT)) + print("Send Complete") + await asyncio.sleep(5) + + except MQTTConnectionDroppedError: + # Connection has been lost. + print("Dropped connection. Exiting") + except MQTTConnectionFailedError: + # Connection failed to be established. + print("Could not connect. Exiting") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + # Exit application because user indicated they wish to exit. + # This will have cancelled `main()` implicitly. + print("User initiated exit. 
Exiting") + finally: + print("Sent {} messages in total".format(TOTAL_MESSAGES_SENT)) diff --git a/samples/iothub_simple_telemetry_with_reconnect.py b/samples/iothub_simple_telemetry_with_reconnect.py new file mode 100644 index 000000000..105ebb4ed --- /dev/null +++ b/samples/iothub_simple_telemetry_with_reconnect.py @@ -0,0 +1,52 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""This sample demonstrates a simple recurring telemetry using an IoTHubSession. +If the connection drops, it will try to establish one again until the user exits. +""" + +import asyncio +import os +from azure.iot.device import IoTHubSession, MQTTConnectionDroppedError, MQTTConnectionFailedError + +CONNECTION_STRING = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") +TOTAL_MESSAGES_SENT = 0 + + +async def main(): + global TOTAL_MESSAGES_SENT + print("Starting telemetry sample") + print("Press Ctrl-C to exit") + while True: + try: + print("Connecting to IoT Hub...") + async with IoTHubSession.from_connection_string(CONNECTION_STRING) as session: + print("Connected to IoT Hub") + while True: + TOTAL_MESSAGES_SENT += 1 + print("Sending Message #{}...".format(TOTAL_MESSAGES_SENT)) + await session.send_message("Message #{}".format(TOTAL_MESSAGES_SENT)) + print("Send Complete") + await asyncio.sleep(5) + + except MQTTConnectionDroppedError: + # Connection has been lost. Reconnect on next pass of loop. + print("Dropped connection. Reconnecting in 1 second") + await asyncio.sleep(1) + except MQTTConnectionFailedError: + # Connection failed to be established. Retry on next pass of loop. + print("Could not connect. 
Retrying in 10 seconds") + await asyncio.sleep(10) + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + # Exit application because user indicated they wish to exit. + # This will have cancelled `main()` implicitly. + print("User initiated exit. Exiting.") + finally: + print("Sent {} messages in total.".format(TOTAL_MESSAGES_SENT)) diff --git a/samples/media/codespace.png b/samples/media/codespace.png deleted file mode 100644 index 32355e7e4..000000000 Binary files a/samples/media/codespace.png and /dev/null differ diff --git a/samples/pnp/README.md b/samples/pnp/README.md deleted file mode 100644 index a2ddae1b4..000000000 --- a/samples/pnp/README.md +++ /dev/null @@ -1,62 +0,0 @@ ---- -page_type: sample -description: "A set of Python samples that show how a device that uses the IoT Plug and Play conventions interacts with either IoT Hub or IoT Central." -languages: -- python -products: -- azure-iot-hub -- azure-iot-central -- azure-iot-pnp -urlFragment: azure-iot-pnp-device-samples-for-python ---- - -# IoT Plug And Play device samples - -[![Documentation](../../doc/images/docs-link-buttons/azure-documentation.svg)](https://docs.microsoft.com/azure/iot-develop/) - -These samples demonstrate how a device that follows the [IoT Plug and Play conventions](https://docs.microsoft.com/azure/iot-pnp/concepts-convention) interacts with IoT Hub or IoT Central, to: - -- Send telemetry. -- Update read-only and read-write properties. -- Respond to command invocation. - -The samples demonstrate two scenarios: - -- An IoT Plug and Play device that implements the [Thermostat](https://devicemodels.azure.com/dtmi/com/example/thermostat-1.json) model. This model has a single interface that defines telemetry, read-only and read-write properties, and commands. -- An IoT Plug and Play device that implements the [Temperature controller](https://devicemodels.azure.com/dtmi/com/example/temperaturecontroller-2.json) model. 
This model uses multiple components: - - The top-level interface defines telemetry, read-only property and commands. - - The model includes two [Thermostat](https://devicemodels.azure.com/dtmi/com/example/thermostat-1.json) components, and a [device information](https://devicemodels.azure.com/dtmi/azure/devicemanagement/deviceinformation-1.json) component. - -## Quickstarts and tutorials - -To learn more about how to configure and run the Thermostat device sample with IoT Hub, see [Quickstart: Connect a sample IoT Plug and Play device application running on Linux or Windows to IoT Hub](https://docs.microsoft.com/azure/iot-pnp/quickstart-connect-device?pivots=programming-language-python). - -To learn more about how to configure and run the Temperature Controller device sample with: - -- IoT Hub, see [Tutorial: Connect an IoT Plug and Play multiple component device application running on Linux or Windows to IoT Hub](https://docs.microsoft.com/azure/iot-pnp/tutorial-multiple-components?pivots=programming-language-python) -- IoT Central, see [Tutorial: Create and connect a client application to your Azure IoT Central application](https://docs.microsoft.com/azure/iot-central/core/tutorial-connect-device?pivots=programming-language-python) - -## Configuring the samples - -Both samples use environment variables to retrieve configuration. - -* If you are using a connection string to authenticate: - * set IOTHUB_DEVICE_SECURITY_TYPE="connectionString" - * set IOTHUB_DEVICE_CONNECTION_STRING="\" - -* If you are using a DPS enrollment group to authenticate: - * set IOTHUB_DEVICE_SECURITY_TYPE="DPS" - * set IOTHUB_DEVICE_DPS_ID_SCOPE="\" - * set IOTHUB_DEVICE_DPS_DEVICE_ID="\" - * set IOTHUB_DEVICE_DPS_DEVICE_KEY="\" - * set IOTHUB_DEVICE_DPS_ENDPOINT="\" - -## Caveats - -* Azure IoT Plug and Play is only supported for MQTT and MQTT over WebSockets for the Azure IoT Python Device SDK. 
Modifying these samples to use AMQP, AMQP over WebSockets, or HTTP protocols **will not work**. - -* When the thermostat receives a desired temperature, it has no actual affect on the current temperature. - -* The command `getMaxMinReport` allows the application to specify statistics of the temperature since a given date. To keep the sample simple, we ignore this field and instead return statistics from the some portion of the lifecycle of the executable. - -* The temperature controller implements a command named `reboot` which takes a request payload indicating the delay in seconds. The sample will ignore doing anything on this command. diff --git a/samples/pnp/pnp_helper.py b/samples/pnp/pnp_helper.py deleted file mode 100644 index 39425530a..000000000 --- a/samples/pnp/pnp_helper.py +++ /dev/null @@ -1,132 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -This module knows how to convert device SDK functionality into a plug and play functionality. -These methods formats the telemetry, methods, properties to plug and play relevant telemetry, -command requests and pnp properties. 
-""" -from azure.iot.device import Message -import json - - -class PnpProperties(object): - def __init__(self, top_key, **kwargs): - self._top_key = top_key - for name in kwargs: - setattr(self, name, kwargs[name]) - - def _to_value_dict(self): - all_attrs = list((x for x in self.__dict__ if x != "_top_key")) - inner = {key: {"value": getattr(self, key)} for key in all_attrs} - return inner - - def _to_simple_dict(self): - all_simple_attrs = list((x for x in self.__dict__ if x != "_top_key")) - inner = {key: getattr(self, key) for key in all_simple_attrs} - return inner - - -def create_telemetry(telemetry_msg, component_name=None): - """ - Function to create telemetry for a plug and play device. This function will take the raw telemetry message - in the form of a dictionary from the user and then create a plug and play specific message. - :param telemetry_msg: A dictionary of items to be sent as telemetry. - :param component_name: The name of the device like "sensor" - :return: The message. - """ - msg = Message(json.dumps(telemetry_msg)) - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - if component_name: - msg.custom_properties["$.sub"] = component_name - return msg - - -def create_reported_properties(component_name=None, **prop_kwargs): - """ - Function to create properties for a plug and play device. This method will take in the user properties passed as - key word arguments and then creates plug and play specific reported properties. - :param component_name: The name of the component. Like "deviceinformation" or "sdkinformation" - :param prop_kwargs: The user passed keyword arguments which are the properties that the user wants to update. - :return: The dictionary of properties. 
- """ - if component_name: - print("Updating pnp properties for {component_name}".format(component_name=component_name)) - else: - print("Updating pnp properties for root interface") - prop_object = PnpProperties(component_name, **prop_kwargs) - inner_dict = prop_object._to_simple_dict() - if component_name: - inner_dict["__t"] = "c" - prop_dict = {} - prop_dict[component_name] = inner_dict - else: - prop_dict = inner_dict - - print(prop_dict) - return prop_dict - - -def create_response_payload_with_status(command_request, method_name, create_user_response=None): - """ - Helper method to create the payload for responding to a command request. - This method is used for all method responses unless the user provides another - method to construct responses to specific command requests. - :param command_request: The command request for which the response is being sent. - :param method_name: The method name for which we are responding to. - :param create_user_response: Function to create user specific response. - :return: The response payload. - """ - if method_name: - response_status = 200 - else: - response_status = 404 - - if not create_user_response: - result = True if method_name else False - data = "executed " + method_name if method_name else "unknown method" - response_payload = {"result": result, "data": data} - else: - response_payload = create_user_response(command_request.payload) - - return (response_status, response_payload) - - -def create_reported_properties_from_desired(patch): - """ - Function to create properties for a plug and play device. This method will take in the desired properties patch. - and then create plug and play specific reported properties. - :param patch: The patch of desired properties. - :return: The dictionary of properties. 
- """ - print("the data in the desired properties patch was: {}".format(patch)) - - ignore_keys = ["__t", "$version"] - component_prefix = list(patch.keys())[0] - values = patch[component_prefix] - print("Values received are :-") - print(values) - - version = patch["$version"] - inner_dict = {} - - for prop_name, prop_value in values.items(): - if prop_name in ignore_keys: - continue - else: - inner_dict["ac"] = 200 - inner_dict["ad"] = "Successfully executed patch" - inner_dict["av"] = version - inner_dict["value"] = prop_value - values[prop_name] = inner_dict - - properties_dict = dict() - if component_prefix: - properties_dict[component_prefix] = values - else: - properties_dict = values - - return properties_dict diff --git a/samples/pnp/simple_thermostat.py b/samples/pnp/simple_thermostat.py deleted file mode 100644 index 2ea144447..000000000 --- a/samples/pnp/simple_thermostat.py +++ /dev/null @@ -1,342 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import os -import asyncio -import random -import logging -import json - -from azure.iot.device.aio import IoTHubDeviceClient -from azure.iot.device.aio import ProvisioningDeviceClient -from azure.iot.device import Message, MethodResponse -from datetime import timedelta, datetime - -logging.basicConfig(level=logging.ERROR) - -# The device "Thermostat" that is getting implemented using the above interfaces. 
-# This id can change according to the company the user is from -# and the name user wants to call this Plug and Play device -model_id = "dtmi:com:example:Thermostat;1" - -##################################################### -# GLOBAL THERMOSTAT VARIABLES -max_temp = None -min_temp = None -avg_temp_list = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -moving_window_size = len(avg_temp_list) -target_temperature = None - - -##################################################### -# COMMAND HANDLERS : User will define these handlers -# depending on what commands the DTMI defines - - -async def reboot_handler(values): - global max_temp - global min_temp - global avg_temp_list - global target_temperature - if values and type(values) == int: - print("Rebooting after delay of {delay} secs".format(delay=values)) - asyncio.sleep(values) - max_temp = None - min_temp = None - for idx in range(len(avg_temp_list)): - avg_temp_list[idx] = 0 - target_temperature = None - print("maxTemp {}, minTemp {}".format(max_temp, min_temp)) - print("Done rebooting") - - -async def max_min_handler(values): - if values: - print( - "Will return the max, min and average temperature from the specified time {since} to the current time".format( - since=values - ) - ) - print("Done generating") - - -# END COMMAND HANDLERS -##################################################### - -##################################################### -# CREATE RESPONSES TO COMMANDS - - -def create_max_min_report_response(values): - """ - An example function that can create a response to the "getMaxMinReport" command request the way the user wants it. - Most of the times response is created by a helper function which follows a generic pattern. - This should be only used when the user wants to give a detailed response back to the Hub. - :param values: The values that were received as part of the request. 
- """ - response_dict = { - "maxTemp": max_temp, - "minTemp": min_temp, - "avgTemp": sum(avg_temp_list) / moving_window_size, - "startTime": (datetime.now() - timedelta(0, moving_window_size * 8)).isoformat(), - "endTime": datetime.now().isoformat(), - } - # serialize response dictionary into a JSON formatted str - response_payload = json.dumps(response_dict, default=lambda o: o.__dict__, sort_keys=True) - print(response_payload) - return response_payload - - -def create_reboot_response(values): - response = {"result": True, "data": "reboot succeeded"} - return response - - -# END CREATE RESPONSES TO COMMANDS -##################################################### - -##################################################### -# TELEMETRY TASKS - - -async def send_telemetry_from_thermostat(device_client, telemetry_msg): - msg = Message(json.dumps(telemetry_msg)) - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - print("Sent message") - await device_client.send_message(msg) - - -# END TELEMETRY TASKS -##################################################### - -##################################################### -# CREATE COMMAND AND PROPERTY LISTENERS - - -async def execute_command_listener( - device_client, method_name, user_command_handler, create_user_response_handler -): - while True: - if method_name: - command_name = method_name - else: - command_name = None - - command_request = await device_client.receive_method_request(command_name) - print("Command request received with payload") - print(command_request.payload) - - values = {} - if not command_request.payload: - print("Payload was empty.") - else: - values = command_request.payload - - await user_command_handler(values) - - response_status = 200 - response_payload = create_user_response_handler(values) - - command_response = MethodResponse.create_from_method_request( - command_request, response_status, response_payload - ) - - try: - await 
device_client.send_method_response(command_response) - except Exception: - print("responding to the {command} command failed".format(command=method_name)) - - -async def execute_property_listener(device_client): - ignore_keys = ["__t", "$version"] - while True: - patch = await device_client.receive_twin_desired_properties_patch() # blocking call - - print("the data in the desired properties patch was: {}".format(patch)) - - version = patch["$version"] - prop_dict = {} - - for prop_name, prop_value in patch.items(): - if prop_name in ignore_keys: - continue - else: - prop_dict[prop_name] = { - "ac": 200, - "ad": "Successfully executed patch", - "av": version, - "value": prop_value, - } - - await device_client.patch_twin_reported_properties(prop_dict) - - -# END COMMAND AND PROPERTY LISTENERS -##################################################### - -##################################################### -# An # END KEYBOARD INPUT LISTENER to quit application - - -def stdin_listener(): - """ - Listener for quitting the sample - """ - while True: - selection = input("Press Q to quit\n") - if selection == "Q" or selection == "q": - print("Quitting...") - break - - -# END KEYBOARD INPUT LISTENER -##################################################### - - -##################################################### -# PROVISION DEVICE -async def provision_device(provisioning_host, id_scope, registration_id, symmetric_key, model_id): - provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=provisioning_host, - registration_id=registration_id, - id_scope=id_scope, - symmetric_key=symmetric_key, - ) - provisioning_device_client.provisioning_payload = {"modelId": model_id} - return await provisioning_device_client.register() - - -##################################################### -# MAIN STARTS -async def main(): - switch = os.getenv("IOTHUB_DEVICE_SECURITY_TYPE") - if switch == "DPS": - provisioning_host = ( - 
os.getenv("IOTHUB_DEVICE_DPS_ENDPOINT") - if os.getenv("IOTHUB_DEVICE_DPS_ENDPOINT") - else "global.azure-devices-provisioning.net" - ) - id_scope = os.getenv("IOTHUB_DEVICE_DPS_ID_SCOPE") - registration_id = os.getenv("IOTHUB_DEVICE_DPS_DEVICE_ID") - symmetric_key = os.getenv("IOTHUB_DEVICE_DPS_DEVICE_KEY") - - registration_result = await provision_device( - provisioning_host, id_scope, registration_id, symmetric_key, model_id - ) - - if registration_result.status == "assigned": - print("Device was assigned") - print(registration_result.registration_state.assigned_hub) - print(registration_result.registration_state.device_id) - - device_client = IoTHubDeviceClient.create_from_symmetric_key( - symmetric_key=symmetric_key, - hostname=registration_result.registration_state.assigned_hub, - device_id=registration_result.registration_state.device_id, - product_info=model_id, - ) - else: - raise RuntimeError( - "Could not provision device. Aborting Plug and Play device connection." - ) - - elif switch == "connectionString": - conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - print("Connecting using Connection String " + conn_str) - device_client = IoTHubDeviceClient.create_from_connection_string( - conn_str, product_info=model_id - ) - else: - raise RuntimeError( - "At least one choice needs to be made for complete functioning of this sample." - ) - - # Connect the client. 
- await device_client.connect() - - ################################################ - # Set and read desired property (target temperature) - - max_temp = 10.96 # Initial Max Temp otherwise will not pass certification - await device_client.patch_twin_reported_properties({"maxTempSinceLastReboot": max_temp}) - - ################################################ - # Register callback and Handle command (reboot) - print("Listening for command requests and property updates") - - listeners = asyncio.gather( - execute_command_listener( - device_client, - method_name="reboot", - user_command_handler=reboot_handler, - create_user_response_handler=create_reboot_response, - ), - execute_command_listener( - device_client, - method_name="getMaxMinReport", - user_command_handler=max_min_handler, - create_user_response_handler=create_max_min_report_response, - ), - execute_property_listener(device_client), - ) - - ################################################ - # Send telemetry (current temperature) - - async def send_telemetry(): - print("Sending telemetry for temperature") - global max_temp - global min_temp - current_avg_idx = 0 - - while True: - current_temp = random.randrange(10, 50) # Current temperature in Celsius - if not max_temp: - max_temp = current_temp - elif current_temp > max_temp: - max_temp = current_temp - - if not min_temp: - min_temp = current_temp - elif current_temp < min_temp: - min_temp = current_temp - - avg_temp_list[current_avg_idx] = current_temp - current_avg_idx = (current_avg_idx + 1) % moving_window_size - - temperature_msg1 = {"temperature": current_temp} - await send_telemetry_from_thermostat(device_client, temperature_msg1) - await asyncio.sleep(8) - - send_telemetry_task = asyncio.create_task(send_telemetry()) - - # Run the stdin listener in the event loop - loop = asyncio.get_running_loop() - user_finished = loop.run_in_executor(None, stdin_listener) - # # Wait for user to indicate they are done listening for method calls - await 
user_finished - - if not listeners.done(): - listeners.set_result("DONE") - - listeners.cancel() - - send_telemetry_task.cancel() - - # Finally, shut down the client - await device_client.shutdown() - - -##################################################### -# EXECUTE MAIN - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/pnp/temp_controller_with_thermostats.py b/samples/pnp/temp_controller_with_thermostats.py deleted file mode 100644 index 79310d641..000000000 --- a/samples/pnp/temp_controller_with_thermostats.py +++ /dev/null @@ -1,424 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import os -import asyncio -import random -import logging - -from azure.iot.device.aio import IoTHubDeviceClient -from azure.iot.device.aio import ProvisioningDeviceClient -from azure.iot.device import MethodResponse -from datetime import timedelta, datetime -import pnp_helper - -logging.basicConfig(level=logging.ERROR) - -# the interfaces that are pulled in to implement the device. -# User has to know these values as these may change and user can -# choose to implement different interfaces. -thermostat_digital_twin_model_identifier = "dtmi:com:example:Thermostat;1" -device_info_digital_twin_model_identifier = "dtmi:azure:DeviceManagement:DeviceInformation;1" - -# The device "TemperatureController" that is getting implemented using the above interfaces. 
-# This id can change according to the company the user is from -# and the name user wants to call this Plug and Play device -model_id = "dtmi:com:example:TemperatureController;2" - -# the components inside this Plug and Play device. -# there can be multiple components from 1 interface -# component names according to interfaces following pascal case. -device_information_component_name = "deviceInformation" -thermostat_1_component_name = "thermostat1" -thermostat_2_component_name = "thermostat2" -serial_number = "some_serial_number" -##################################################### -# COMMAND HANDLERS : User will define these handlers -# depending on what commands the component defines - -##################################################### -# GLOBAL VARIABLES -THERMOSTAT_1 = None -THERMOSTAT_2 = None - - -class Thermostat(object): - def __init__(self, name, moving_win=10): - - self.moving_window = moving_win - self.records = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - self.index = 0 - - self.cur = 0 - self.max = 0 - self.min = 0 - self.avg = 0 - - self.name = name - - def record(self, current_temp): - self.cur = current_temp - self.records[self.index] = current_temp - self.max = self.calculate_max(current_temp) - self.min = self.calculate_min(current_temp) - self.avg = self.calculate_average() - - self.index = (self.index + 1) % self.moving_window - - def calculate_max(self, current_temp): - if not self.max: - return current_temp - elif current_temp > self.max: - return self.max - - def calculate_min(self, current_temp): - if not self.min: - return current_temp - elif current_temp < self.min: - return self.min - - def calculate_average(self): - return sum(self.records) / self.moving_window - - def create_report(self): - response_dict = {} - response_dict["maxTemp"] = self.max - response_dict["minTemp"] = self.min - response_dict["avgTemp"] = self.avg - response_dict["startTime"] = ( - (datetime.now() - timedelta(0, self.moving_window * 8)).astimezone().isoformat() - ) - 
response_dict["endTime"] = datetime.now().astimezone().isoformat() - return response_dict - - -async def reboot_handler(values): - if values: - print("Rebooting after delay of {delay} secs".format(delay=values)) - print("Done rebooting") - - -async def max_min_handler(values): - if values: - print( - "Will return the max, min and average temperature from the specified time {since} to the current time".format( - since=values - ) - ) - print("Done generating") - - -# END COMMAND HANDLERS -##################################################### - -##################################################### -# CREATE RESPONSES TO COMMANDS - - -def create_max_min_report_response(thermostat_name): - """ - An example function that can create a response to the "getMaxMinReport" command request the way the user wants it. - Most of the times response is created by a helper function which follows a generic pattern. - This should be only used when the user wants to give a detailed response back to the Hub. - :param values: The values that were received as part of the request. - """ - if "Thermostat;1" in thermostat_name and THERMOSTAT_1: - response_dict = THERMOSTAT_1.create_report() - elif THERMOSTAT_2: - response_dict = THERMOSTAT_2.create_report() - else: # This is done to pass certification. 
- response_dict = {} - response_dict["maxTemp"] = 0 - response_dict["minTemp"] = 0 - response_dict["avgTemp"] = 0 - response_dict["startTime"] = datetime.now().astimezone().isoformat() - response_dict["endTime"] = datetime.now().astimezone().isoformat() - - print(response_dict) - return response_dict - - -# END CREATE RESPONSES TO COMMANDS -##################################################### - -##################################################### -# TELEMETRY TASKS - - -async def send_telemetry_from_temp_controller(device_client, telemetry_msg, component_name=None): - msg = pnp_helper.create_telemetry(telemetry_msg, component_name) - await device_client.send_message(msg) - print("Sent message") - print(msg) - await asyncio.sleep(5) - - -##################################################### -# COMMAND TASKS - - -async def execute_command_listener( - device_client, - component_name=None, - method_name=None, - user_command_handler=None, - create_user_response_handler=None, -): - """ - Coroutine for executing listeners. These will listen for command requests. - They will take in a user provided handler and call the user provided handler - according to the command request received. - :param device_client: The device client - :param component_name: The name of the device like "sensor" - :param method_name: (optional) The specific method name to listen for. Eg could be "blink", "turnon" etc. - If not provided the listener will listen for all methods. - :param user_command_handler: (optional) The user provided handler that needs to be executed after receiving "command requests". - If not provided nothing will be executed on receiving command. - :param create_user_response_handler: (optional) The user provided handler that will create a response. - If not provided a generic response will be created. 
- :return: - """ - while True: - if component_name and method_name: - command_name = component_name + "*" + method_name - elif method_name: - command_name = method_name - else: - command_name = None - - command_request = await device_client.receive_method_request(command_name) - print("Command request received with payload") - values = command_request.payload - print(values) - - if user_command_handler: - await user_command_handler(values) - else: - print("No handler provided to execute") - - (response_status, response_payload) = pnp_helper.create_response_payload_with_status( - command_request, method_name, create_user_response=create_user_response_handler - ) - - command_response = MethodResponse.create_from_method_request( - command_request, response_status, response_payload - ) - - try: - await device_client.send_method_response(command_response) - except Exception: - print("responding to the {command} command failed".format(command=method_name)) - - -##################################################### -# PROPERTY TASKS - - -async def execute_property_listener(device_client): - while True: - patch = await device_client.receive_twin_desired_properties_patch() # blocking call - print(patch) - properties_dict = pnp_helper.create_reported_properties_from_desired(patch) - - await device_client.patch_twin_reported_properties(properties_dict) - - -##################################################### -# An # END KEYBOARD INPUT LISTENER to quit application - - -def stdin_listener(): - """ - Listener for quitting the sample - """ - while True: - selection = input("Press Q to quit\n") - if selection == "Q" or selection == "q": - print("Quitting...") - break - - -# END KEYBOARD INPUT LISTENER -##################################################### - - -##################################################### -# MAIN STARTS -async def provision_device(provisioning_host, id_scope, registration_id, symmetric_key, model_id): - provisioning_device_client = 
ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=provisioning_host, - registration_id=registration_id, - id_scope=id_scope, - symmetric_key=symmetric_key, - ) - - provisioning_device_client.provisioning_payload = {"modelId": model_id} - return await provisioning_device_client.register() - - -async def main(): - switch = os.getenv("IOTHUB_DEVICE_SECURITY_TYPE") - if switch == "DPS": - provisioning_host = ( - os.getenv("IOTHUB_DEVICE_DPS_ENDPOINT") - if os.getenv("IOTHUB_DEVICE_DPS_ENDPOINT") - else "global.azure-devices-provisioning.net" - ) - id_scope = os.getenv("IOTHUB_DEVICE_DPS_ID_SCOPE") - registration_id = os.getenv("IOTHUB_DEVICE_DPS_DEVICE_ID") - symmetric_key = os.getenv("IOTHUB_DEVICE_DPS_DEVICE_KEY") - - registration_result = await provision_device( - provisioning_host, id_scope, registration_id, symmetric_key, model_id - ) - - if registration_result.status == "assigned": - print("Device was assigned") - print(registration_result.registration_state.assigned_hub) - print(registration_result.registration_state.device_id) - device_client = IoTHubDeviceClient.create_from_symmetric_key( - symmetric_key=symmetric_key, - hostname=registration_result.registration_state.assigned_hub, - device_id=registration_result.registration_state.device_id, - product_info=model_id, - ) - else: - raise RuntimeError( - "Could not provision device. Aborting Plug and Play device connection." - ) - - elif switch == "connectionString": - conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - print("Connecting using Connection String " + conn_str) - device_client = IoTHubDeviceClient.create_from_connection_string( - conn_str, product_info=model_id - ) - else: - raise RuntimeError( - "At least one choice needs to be made for complete functioning of this sample." - ) - - # Connect the client. 
- await device_client.connect() - - ################################################ - # Update readable properties from various components - - properties_root = pnp_helper.create_reported_properties(serialNumber=serial_number) - properties_thermostat1 = pnp_helper.create_reported_properties( - thermostat_1_component_name, maxTempSinceLastReboot=98.34 - ) - properties_thermostat2 = pnp_helper.create_reported_properties( - thermostat_2_component_name, maxTempSinceLastReboot=48.92 - ) - properties_device_info = pnp_helper.create_reported_properties( - device_information_component_name, - swVersion="5.5", - manufacturer="Contoso Device Corporation", - model="Contoso 4762B-turbo", - osName="Mac Os", - processorArchitecture="x86-64", - processorManufacturer="Intel", - totalStorage=1024, - totalMemory=32, - ) - - property_updates = asyncio.gather( - device_client.patch_twin_reported_properties(properties_root), - device_client.patch_twin_reported_properties(properties_thermostat1), - device_client.patch_twin_reported_properties(properties_thermostat2), - device_client.patch_twin_reported_properties(properties_device_info), - ) - - ################################################ - # Get all the listeners running - print("Listening for command requests and property updates") - - global THERMOSTAT_1 - global THERMOSTAT_2 - THERMOSTAT_1 = Thermostat(thermostat_1_component_name, 10) - THERMOSTAT_2 = Thermostat(thermostat_2_component_name, 10) - - listeners = asyncio.gather( - execute_command_listener( - device_client, method_name="reboot", user_command_handler=reboot_handler - ), - execute_command_listener( - device_client, - thermostat_1_component_name, - method_name="getMaxMinReport", - user_command_handler=max_min_handler, - create_user_response_handler=create_max_min_report_response, - ), - execute_command_listener( - device_client, - thermostat_2_component_name, - method_name="getMaxMinReport", - user_command_handler=max_min_handler, - 
create_user_response_handler=create_max_min_report_response, - ), - execute_property_listener(device_client), - ) - - ################################################ - # Function to send telemetry every 8 seconds - - async def send_telemetry(): - print("Sending telemetry from various components") - - while True: - curr_temp_ext = random.randrange(10, 50) - THERMOSTAT_1.record(curr_temp_ext) - - temperature_msg1 = {"temperature": curr_temp_ext} - await send_telemetry_from_temp_controller( - device_client, temperature_msg1, thermostat_1_component_name - ) - - curr_temp_int = random.randrange(10, 50) # Current temperature in Celsius - THERMOSTAT_2.record(curr_temp_int) - - temperature_msg2 = {"temperature": curr_temp_int} - - await send_telemetry_from_temp_controller( - device_client, temperature_msg2, thermostat_2_component_name - ) - - workingset_msg3 = {"workingSet": random.randrange(1, 100)} - await send_telemetry_from_temp_controller(device_client, workingset_msg3) - - send_telemetry_task = asyncio.ensure_future(send_telemetry()) - - # Run the stdin listener in the event loop - loop = asyncio.get_running_loop() - user_finished = loop.run_in_executor(None, stdin_listener) - # # Wait for user to indicate they are done listening for method calls - await user_finished - - if not listeners.done(): - listeners.set_result("DONE") - - if not property_updates.done(): - property_updates.set_result("DONE") - - listeners.cancel() - property_updates.cancel() - - send_telemetry_task.cancel() - - # Finally, shut down the client - await device_client.shutdown() - - -##################################################### -# EXECUTE MAIN - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/provisioning_group_x509_client_certs.py b/samples/provisioning_group_x509_client_certs.py new file mode 
100644 index 000000000..ceb167a00 --- /dev/null +++ b/samples/provisioning_group_x509_client_certs.py @@ -0,0 +1,73 @@ +import asyncio +import logging +from azure.iot.device import ( + ProvisioningSession, + MQTTConnectionDroppedError, + MQTTConnectionFailedError, +) +import os +import ssl + +id_scope = os.getenv("PROVISIONING_IDSCOPE") + +logging.basicConfig(level=logging.DEBUG) + + +def create_default_context(certfile, keyfile, password): + ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.check_hostname = True + ssl_context.load_default_certs() + ssl_context.load_cert_chain(certfile=certfile, keyfile=keyfile, password=password) + return ssl_context + + +async def run_dps(registration_id, certfile, keyfile, password): + try: + ssl_context = create_default_context(certfile, keyfile, password) + async with ProvisioningSession( + registration_id=registration_id, + id_scope=id_scope, + ssl_context=ssl_context, + ) as session: + print("Connected") + result = await session.register() + print("Finished provisioning") + print(result) + + except MQTTConnectionDroppedError as me: + # Connection has been lost. + print("Dropped connection. Exiting") + raise Exception("Dropped connection") from me + except MQTTConnectionFailedError as mce: + # Connection failed to be established. + print("Could not connect. Exiting") + raise Exception("Could not connect. Exiting") from mce + + +async def main(): + print("Starting group provisioning sample") + print("Press Ctrl-C to exit") + + try: + # These are all fake file names and password and to be replaced as per scenario. 
+ await asyncio.gather( + run_dps("devicemydomain1", "device_cert1.pem", "device_key1.pem", "devicepass"), + run_dps("devicemydomain2", "device_cert2.pem", "device_key2.pem", "devicepass"), + run_dps("devicemydomain3", "device_cert3.pem", "device_key3.pem", "devicepass"), + ) + except Exception as e: + print("Caught exception while trying to run dps") + print(e.__cause__) + finally: + print("Finishing sample") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + # Exit application because user indicated they wish to exit. + # This will have cancelled `main()` implicitly. + print("User initiated exit. Exiting") diff --git a/samples/provisioning_symmetric_key.py b/samples/provisioning_symmetric_key.py new file mode 100644 index 000000000..e387478d4 --- /dev/null +++ b/samples/provisioning_symmetric_key.py @@ -0,0 +1,45 @@ +import asyncio +import logging +from azure.iot.device import ( + ProvisioningSession, + MQTTConnectionDroppedError, + MQTTConnectionFailedError, +) +import os + + +id_scope = os.getenv("PROVISIONING_IDSCOPE") +registration_id = os.getenv("PROVISIONING_REGISTRATION_ID") +symmetric_key = os.getenv("PROVISIONING_SYMMETRIC_KEY") + + +logging.basicConfig(level=logging.DEBUG) + + +async def main(): + try: + async with ProvisioningSession( + registration_id=registration_id, + id_scope=id_scope, + shared_access_key=symmetric_key, + ) as session: + print("Connected") + result = await session.register(payload="optional registration payload") + print("Finished provisioning") + print(result) + + except MQTTConnectionDroppedError: + # Connection has been lost. + print("Dropped connection. Exiting") + except MQTTConnectionFailedError: + # Connection failed to be established. + print("Could not connect. Exiting") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + # Exit application because user indicated they wish to exit. + # This will have cancelled `main()` implicitly. 
+ print("User initiated exit. Exiting") diff --git a/samples/provisioning_x509_client_certs.py b/samples/provisioning_x509_client_certs.py new file mode 100644 index 000000000..de64f1adf --- /dev/null +++ b/samples/provisioning_x509_client_certs.py @@ -0,0 +1,56 @@ +import asyncio +import logging +from azure.iot.device import ( + ProvisioningSession, + MQTTConnectionDroppedError, + MQTTConnectionFailedError, +) +import os +import ssl + + +id_scope = os.getenv("PROVISIONING_IDSCOPE") +registration_id = os.getenv("PROVISIONING_REGISTRATION_ID") + +logging.basicConfig(level=logging.DEBUG) + + +async def main(): + try: + ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.check_hostname = True + ssl_context.load_default_certs() + # These are all fake file names and password and to be replaced as per scenario. + ssl_context.load_cert_chain( + certfile="device_cert4.pem", + keyfile="device_key4.pem", + password="devicepass", + ) + + async with ProvisioningSession( + registration_id=registration_id, + id_scope=id_scope, + ssl_context=ssl_context, + ) as session: + print("Connected") + result = await session.register(payload="optional registration payload") + print("Finished provisioning") + print(result) + + except MQTTConnectionDroppedError: + # Connection has been lost. + print("Dropped connection. Exiting") + except MQTTConnectionFailedError: + # Connection failed to be established. + print("Could not connect. Exiting") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + # Exit application because user indicated they wish to exit. + # This will have cancelled `main()` implicitly. + print("User initiated exit. 
Exiting") diff --git a/samples/simple_send_message.py b/samples/simple_send_message.py deleted file mode 100644 index 5d7deda1d..000000000 --- a/samples/simple_send_message.py +++ /dev/null @@ -1,37 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -import asyncio -from azure.iot.device.aio import IoTHubDeviceClient - - -async def main(): - # Fetch the connection string from an environment variable - conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - - # Create instance of the device client using the connection string - device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - - # Connect the device client. - await device_client.connect() - - # Send a single message - print("Sending message...") - await device_client.send_message("This is a message that is being sent") - print("Message successfully sent!") - - # Finally, shut down the client - await device_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() diff --git a/samples/sync-samples/README.md b/samples/sync-samples/README.md deleted file mode 100644 index d0a50619b..000000000 --- a/samples/sync-samples/README.md +++ /dev/null @@ -1,81 +0,0 @@ -# Synchronous API Scenario Samples for the Azure IoT Hub Device SDK - -This directory contains samples showing how to use the various features of Azure IoT Hub Device SDK with the Azure IoT Hub and Azure IoT Edge. 
- -## IoTHub Device Samples - -In order to use these samples, you **must** set your Device Connection String in the environment variable `IOTHUB_DEVICE_CONNECTION_STRING`. - -* [send_message.py](send_message.py) - Send multiple telemetry messages in parallel from a device to the Azure IoT Hub. - * You can monitor the Azure IoT Hub for messages received by using the following Azure CLI command: - - ```Shell - bash az iot hub monitor-events --hub-name --output table``` - -* [recurring_telemetry.py](recurring_telemetry.py) - Send telemetry message every two seconds from a device to the Azure IoT Hub. - * You can monitor the Azure IoT Hub for messages received by using the following Azure CLI command: - - ```Shell - bash az iot hub monitor-events --hub-name --output table``` - -* [receive_message.py](receive_message.py) - Receive Cloud-to-Device (C2D) messages sent from the Azure IoT Hub to a device. - * In order to send a C2D message, use the following Azure CLI command: - - ```Shell - az iot device c2d-message send --device-id --hub-name --data - ``` - -* [receive_direct_method.py](receive_direct_method.py) - Receive direct method requests on a device from the Azure IoT Hub and send responses back - * In order to invoke a direct method, use the following Azure CLI command: - - ```Shell - az iot hub invoke-device-method --device-id --hub-name --method-name - ``` - -* [receive_twin_desired_properties_patch](receive_twin_desired_properties_patch.py) - Receive an update patch of changes made to the device twin's desired properties - * In order to send a update patch to a device twin's desired properties, use the following Azure CLI command: - - ```Shell - az iot hub device-twin update --device-id --hub-name --set properties.desired.= - ``` - -* [update_twin_reported_properties](update_twin_reported_properties.py) - Send an update patch of changes to the device twin's reported properties - * You can see the changes reflected in your device twin by using the following 
Azure CLI command: - - ```Shell - az iot hub device-twin show --device-id --hub-name - ``` - -## IoT Edge Module Samples - -In order to use these samples, they **must** be run from inside an Edge container. - -* [receive_message_on_input.py](receive_message_on_input.py) - Receive messages sent to an Edge module on a specific module input. -* [send_message_to_output.py](send_message_to_output.py) - Send multiple messages in parallel from an Edge module to a specific output - -## DPS Samples - -### Individual - -In order to use these samples, you **must** have the following environment variables :- - -* PROVISIONING_HOST -* PROVISIONING_IDSCOPE -* PROVISIONING_REGISTRATION_ID - -There are 2 ways that your device can get registered to the provisioning service differing in authentication mechanisms. Depending on the mechanism used additional environment variables are needed for the samples:- - -* [provision_symmetric_key.py](provision_symmetric_key.py) - Provision a device to IoTHub by registering to the Device Provisioning Service using a symmetric key, then send telemetry messages to IoTHub. For this you must have the environment variable PROVISIONING_SYMMETRIC_KEY. -* [provision_symmetric_key_with_payload.py](provision_symmetric_key_with_payload.py) - Provision a device to IoTHub by registering to the Device Provisioning Service using a symmetric key while supplying a custom payload, then send telemetry messages to IoTHub. For this you must have the environment variable PROVISIONING_SYMMETRIC_KEY. -* [provision_x509.py](provision_x509.py) - Provision a device to IoTHub by registering to the Device Provisioning Service using a symmetric key, then send a telemetry message to IoTHub. For this you must have the environment variable X509_CERT_FILE, X509_KEY_FILE, PASS_PHRASE. 
- - -#### Group Enrollment - -In order to use these samples, you **must** have the following environment variables :- - -* PROVISIONING_HOST -* PROVISIONING_IDSCOPE - -* [provision_symmetric_key_group.py](provision_symmetric_key_group.py) - Provision multiple devices to IoTHub by registering them to the Device Provisioning Service using derived symmetric keys, then send telemetry to IoTHub from these devices. For this you must have knowledge of the group symmetric key and must have the environment variables PROVISIONING_DEVICE_ID_1, PROVISIONING_DEVICE_ID_2, PROVISIONING_DEVICE_ID_3. - * NOTE : Group symmetric key must NEVER be stored and all the device keys must be computationally derived prior to using this sample. diff --git a/samples/sync-samples/get_twin.py b/samples/sync-samples/get_twin.py deleted file mode 100644 index 8f9c2fe20..000000000 --- a/samples/sync-samples/get_twin.py +++ /dev/null @@ -1,23 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -from azure.iot.device import IoTHubDeviceClient - - -conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") -device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - -# connect the client. 
-device_client.connect() - -# get the twin -twin = device_client.get_twin() -print("Twin document:") -print("{}".format(twin)) - -# Finally, shut down the client -device_client.shutdown() diff --git a/samples/sync-samples/provision_symmetric_key.py b/samples/sync-samples/provision_symmetric_key.py deleted file mode 100644 index 40c6a522a..000000000 --- a/samples/sync-samples/provision_symmetric_key.py +++ /dev/null @@ -1,61 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from azure.iot.device import ProvisioningDeviceClient -import os -import time -from azure.iot.device import IoTHubDeviceClient, Message -import uuid - -provisioning_host = os.getenv("PROVISIONING_HOST") -id_scope = os.getenv("PROVISIONING_IDSCOPE") -registration_id = os.getenv("PROVISIONING_REGISTRATION_ID") -symmetric_key = os.getenv("PROVISIONING_SYMMETRIC_KEY") - -provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=provisioning_host, - registration_id=registration_id, - id_scope=id_scope, - symmetric_key=symmetric_key, -) - -registration_result = provisioning_device_client.register() -# The result can be directly printed to view the important details. 
-print(registration_result) - -# Individual attributes can be seen as well -print("The status was :-") -print(registration_result.status) -print("The etag is :-") -print(registration_result.registration_state.etag) - -if registration_result.status == "assigned": - print("Will send telemetry from the provisioned device") - # Create device client from the above result - device_client = IoTHubDeviceClient.create_from_symmetric_key( - symmetric_key=symmetric_key, - hostname=registration_result.registration_state.assigned_hub, - device_id=registration_result.registration_state.device_id, - ) - - # Connect the client. - device_client.connect() - - for i in range(1, 6): - print("sending message #" + str(i)) - device_client.send_message("test payload message " + str(i)) - time.sleep(1) - - for i in range(6, 11): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - device_client.send_message(msg) - time.sleep(1) - - # finally, disconnect - device_client.disconnect() -else: - print("Can not send telemetry from the provisioned device") diff --git a/samples/sync-samples/provision_symmetric_key_group.py b/samples/sync-samples/provision_symmetric_key_group.py deleted file mode 100644 index 8eb9ae1c7..000000000 --- a/samples/sync-samples/provision_symmetric_key_group.py +++ /dev/null @@ -1,113 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import os -import base64 -import hmac -import hashlib -from azure.iot.device import ProvisioningDeviceClient -from azure.iot.device import IoTHubDeviceClient -import time - -provisioning_host = os.getenv("PROVISIONING_HOST") -id_scope = os.getenv("PROVISIONING_IDSCOPE") - -# These are the names of the devices that will eventually show up on the IoTHub -# Please make sure that there are no spaces in these device ids. -device_id_1 = os.getenv("PROVISIONING_DEVICE_ID_1") -device_id_2 = os.getenv("PROVISIONING_DEVICE_ID_2") -device_id_3 = os.getenv("PROVISIONING_DEVICE_ID_3") - - -# For computation of device keys -device_ids_to_keys = {} - -# Keep a dictionary for results -results = {} - -# NOTE : Only for illustration purposes. -# This is how a device key can be derived from the group symmetric key. -# This is just a helper function to show how it is done. -# Please don't directly store the master group key on the device. -# Follow the following method to compute the device key somewhere else. - - -def derive_device_key(device_id, group_symmetric_key): - """ - The unique device ID and the group master key should be encoded into "utf-8" - After this the encoded group master key must be used to compute an HMAC-SHA256 of the encoded registration ID. - Finally the result must be converted into Base64 format. - The device key is the "utf-8" decoding of the above result. - """ - message = device_id.encode("utf-8") - signing_key = base64.b64decode(group_symmetric_key.encode("utf-8")) - signed_hmac = hmac.HMAC(signing_key, message, hashlib.sha256) - device_key_encoded = base64.b64encode(signed_hmac.digest()) - return device_key_encoded.decode("utf-8") - - -# derived_device_key has been computed already using the helper function somewhere else -# AND NOT on this sample. Do not use the direct master key on this sample to compute device key. 
-derived_device_key_1 = "some_value_already_computed" -derived_device_key_2 = "some_value_already_computed" -derived_device_key_3 = "some_value_already_computed" - - -device_ids_to_keys[device_id_1] = derived_device_key_1 -device_ids_to_keys[device_id_2] = derived_device_key_2 -device_ids_to_keys[device_id_3] = derived_device_key_3 - - -def register_device(registration_id): - - provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=provisioning_host, - registration_id=registration_id, - id_scope=id_scope, - symmetric_key=device_ids_to_keys[registration_id], - ) - - return provisioning_device_client.register() - - -for device_id in device_ids_to_keys: - registration_result = register_device(registration_id=device_id) - results[device_id] = registration_result - - -for device_id in device_ids_to_keys: - # The result can be directly printed to view the important details. - registration_result = results[device_id] - print(registration_result) - # Individual attributes can be seen as well - print("The status was :-") - print(registration_result.status) - print("The etag is :-") - print(registration_result.registration_state.etag) - print("\n") - if registration_result.status == "assigned": - print("Will send telemetry from the provisioned device with id {id}".format(id=device_id)) - device_client = IoTHubDeviceClient.create_from_symmetric_key( - symmetric_key=device_ids_to_keys[device_id], - hostname=registration_result.registration_state.assigned_hub, - device_id=registration_result.registration_state.device_id, - ) - - # Connect the client. 
- device_client.connect() - - # Send 5 messages - for i in range(1, 6): - print("sending message #" + str(i)) - device_client.send_message("test payload message " + str(i)) - time.sleep(0.5) - - # finally, disconnect - device_client.disconnect() - - else: - print( - "Can not send telemetry from the provisioned device with id {id}".format(id=device_id) - ) diff --git a/samples/sync-samples/provision_symmetric_key_with_payload.py b/samples/sync-samples/provision_symmetric_key_with_payload.py deleted file mode 100644 index 4c553e213..000000000 --- a/samples/sync-samples/provision_symmetric_key_with_payload.py +++ /dev/null @@ -1,74 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import os -from azure.iot.device import ProvisioningDeviceClient -from azure.iot.device import IoTHubDeviceClient, Message -import uuid -import time - - -class Fruit(object): - def __init__(self, first_name, last_name, dict_of_stuff): - self.first_name = first_name - self.last_name = last_name - self.props = dict_of_stuff - - -provisioning_host = os.getenv("PROVISIONING_HOST") -id_scope = os.getenv("PROVISIONING_IDSCOPE") -registration_id = os.getenv("PROVISIONING_REGISTRATION_ID") -symmetric_key = os.getenv("PROVISIONING_SYMMETRIC_KEY") - -provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=provisioning_host, - registration_id=registration_id, - id_scope=id_scope, - symmetric_key=symmetric_key, -) - -properties = {"Type": "Apple", "Sweet": "True"} -fruit_a = Fruit("McIntosh", "Red", properties) - -provisioning_device_client.provisioning_payload = fruit_a -registration_result = provisioning_device_client.register() -# The result can be directly printed to view the important 
details. -print(registration_result) - -# Individual attributes can be seen as well -print("The status was :-") -print(registration_result.status) -print("The etag is :-") -print(registration_result.registration_state.etag) - - -if registration_result.status == "assigned": - print("Will send telemetry from the provisioned device") - # Create device client from the above result - device_client = IoTHubDeviceClient.create_from_symmetric_key( - symmetric_key=symmetric_key, - hostname=registration_result.registration_state.assigned_hub, - device_id=registration_result.registration_state.device_id, - ) - - # Connect the client. - device_client.connect() - - for i in range(1, 6): - print("sending message #" + str(i)) - device_client.send_message("test payload message " + str(i)) - time.sleep(1) - - for i in range(6, 11): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - device_client.send_message(msg) - time.sleep(1) - - # finally, disconnect - device_client.disconnect() -else: - print("Can not send telemetry from the provisioned device") diff --git a/samples/sync-samples/provision_x509.py b/samples/sync-samples/provision_x509.py deleted file mode 100644 index 20b53b505..000000000 --- a/samples/sync-samples/provision_x509.py +++ /dev/null @@ -1,60 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import os -from azure.iot.device import ProvisioningDeviceClient, X509 -import time -from azure.iot.device import IoTHubDeviceClient, Message - - -provisioning_host = os.getenv("PROVISIONING_HOST") -id_scope = os.getenv("PROVISIONING_IDSCOPE") -registration_id = os.getenv("DPS_X509_REGISTRATION_ID") - -x509 = X509( - cert_file=os.getenv("X509_CERT_FILE"), - key_file=os.getenv("X509_KEY_FILE"), - pass_phrase=os.getenv("X509_PASS_PHRASE"), -) - -provisioning_device_client = ProvisioningDeviceClient.create_from_x509_certificate( - provisioning_host=provisioning_host, - registration_id=registration_id, - id_scope=id_scope, - x509=x509, -) - -registration_result = provisioning_device_client.register() - -# The result can be directly printed to view the important details. -print(registration_result) - -if registration_result.status == "assigned": - print("Will send telemetry from the provisioned device") - # Create device client from the above result - device_client = IoTHubDeviceClient.create_from_x509_certificate( - x509=x509, - hostname=registration_result.registration_state.assigned_hub, - device_id=registration_result.registration_state.device_id, - ) - - # Connect the client. 
- device_client.connect() - - for i in range(1, 6): - print("sending message #" + str(i)) - device_client.send_message("test payload message " + str(i)) - time.sleep(1) - - for i in range(6, 11): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - device_client.send_message(msg) - time.sleep(1) - - # finally, disconnect - -else: - print("Can not send telemetry from the provisioned device") diff --git a/samples/sync-samples/receive_direct_method.py b/samples/sync-samples/receive_direct_method.py deleted file mode 100644 index 680369a8f..000000000 --- a/samples/sync-samples/receive_direct_method.py +++ /dev/null @@ -1,54 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -from azure.iot.device import IoTHubDeviceClient, MethodResponse - -# The connection string for a device should never be stored in code. For the sake of simplicity we're using an environment variable here. -conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") -# The client object is used to interact with your Azure IoT hub. -device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - - -# connect the client. 
-device_client.connect() - - -# Define behavior for handling methods -def method_request_handler(method_request): - # Determine how to respond to the method request based on the method name - if method_request.name == "method1": - payload = {"result": True, "data": "some data"} # set response payload - status = 200 # set return status code - print("executed method1") - elif method_request.name == "method2": - payload = {"result": True, "data": 1234} # set response payload - status = 200 # set return status code - print("executed method2") - else: - payload = {"result": False, "data": "unknown method"} # set response payload - status = 400 # set return status code - print("executed unknown method: " + method_request.name) - - # Send the response - method_response = MethodResponse.create_from_method_request(method_request, status, payload) - device_client.send_method_response(method_response) - - -# Set the method request handler on the client -device_client.on_method_request_received = method_request_handler - - -# Wait for user to indicate they are done listening for messages -while True: - selection = input("Press Q to quit\n") - if selection == "Q" or selection == "q": - print("Quitting...") - break - - -# finally, shut down the client -device_client.shutdown() diff --git a/samples/sync-samples/receive_message.py b/samples/sync-samples/receive_message.py deleted file mode 100644 index 0b115bb2e..000000000 --- a/samples/sync-samples/receive_message.py +++ /dev/null @@ -1,41 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -from azure.iot.device import IoTHubDeviceClient - -# The connection string for a device should never be stored in code. 
For the sake of simplicity we're using an environment variable here. -conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") -# The client object is used to interact with your Azure IoT hub. -device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - - -# connect the client. -device_client.connect() - - -# define behavior for receiving a message -def message_handler(message): - print("the data in the message received was ") - print(message.data) - print("custom properties are") - print(message.custom_properties) - - -# set the message handler on the client -device_client.on_message_received = message_handler - - -# Wait for user to indicate they are done listening for messages -while True: - selection = input("Press Q to quit\n") - if selection == "Q" or selection == "q": - print("Quitting...") - break - - -# finally, shut down the client -device_client.shutdown() diff --git a/samples/sync-samples/receive_message_on_input.py b/samples/sync-samples/receive_message_on_input.py deleted file mode 100644 index 1e3a25c3a..000000000 --- a/samples/sync-samples/receive_message_on_input.py +++ /dev/null @@ -1,68 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. 
-# -------------------------------------------------------------------------------------------- - -import threading -import signal -import time -from azure.iot.device import IoTHubModuleClient - - -# Event indicating client stop -stop_event = threading.Event() - - -def create_client(): - # Inputs/Outputs are only supported in the context of Azure IoT Edge and module client - # The module client object acts as an Azure IoT Edge module and interacts with an Azure IoT Edge hub - client = IoTHubModuleClient.create_from_edge_environment() - - # define behavior for receiving a message on inputs 1 and 2 - def message_handler(message): - if message.input_name == "input1": - print("Message received on INPUT 1") - print("the data in the message received was ") - print(message.data) - print("custom properties are") - print(message.custom_properties) - elif message.input_name == "input2": - print("Message received on INPUT 2") - print("the data in the message received was ") - print(message.data) - print("custom properties are") - print(message.custom_properties) - else: - print("message received on unknown input") - - # set the message handler on the client - client.on_message_received = message_handler - - return client - - -def main(): - # The client object is used to interact with your Azure IoT hub. 
- client = create_client() - - def module_termination_handler(signal, frame): - print("IoTHubClient sample stopped by Edge") - stop_event.set() - - # Attach a handler to do cleanup when module is terminated by Edge - signal.signal(signal.SIGTERM, module_termination_handler) - - try: - client.connect() - while not stop_event.is_set(): - time.sleep(100) - except Exception as e: - print("Unexpected error %s " % e) - raise - finally: - print("Shutting down client") - client.shutdown() - - -if __name__ == "__main__": - main() diff --git a/samples/sync-samples/receive_message_x509.py b/samples/sync-samples/receive_message_x509.py deleted file mode 100644 index 26dab6759..000000000 --- a/samples/sync-samples/receive_message_x509.py +++ /dev/null @@ -1,51 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -from azure.iot.device import IoTHubDeviceClient, X509 - -hostname = os.getenv("HOSTNAME") -# The device that has been created on the portal using X509 CA signing or Self signing capabilities -device_id = os.getenv("DEVICE_ID") - -x509 = X509( - cert_file=os.getenv("X509_CERT_FILE"), - key_file=os.getenv("X509_KEY_FILE"), - pass_phrase=os.getenv("X509_PASS_PHRASE"), -) - -# The client object is used to interact with your Azure IoT hub. -device_client = IoTHubDeviceClient.create_from_x509_certificate( - hostname=hostname, device_id=device_id, x509=x509 -) - - -# connect the client. 
-device_client.connect() - - -# define behavior for receiving a message -def message_received_handler(message): - print("the data in the message received was ") - print(message.data) - print("custom properties are") - print(message.custom_properties) - - -# Set the message received handler on the client -device_client.on_message_received = message_received_handler - - -# Wait for user to indicate they are done listening for messages -while True: - selection = input("Press Q to quit\n") - if selection == "Q" or selection == "q": - print("Quitting...") - break - - -# finally, shut down the client -device_client.shutdown() diff --git a/samples/sync-samples/receive_twin_desired_properties_patch.py b/samples/sync-samples/receive_twin_desired_properties_patch.py deleted file mode 100644 index d4c0e3321..000000000 --- a/samples/sync-samples/receive_twin_desired_properties_patch.py +++ /dev/null @@ -1,35 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -from azure.iot.device import IoTHubDeviceClient - -conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") -device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - -# connect the client. 
-device_client.connect() - - -# define behavior for receiving a twin patch -def twin_patch_handler(patch): - print("the data in the desired properties patch was: {}".format(patch)) - - -# set the twin patch handler on the client -device_client.on_twin_desired_properties_patch_received = twin_patch_handler - - -# Wait for user to indicate they are done listening for messages -while True: - selection = input("Press Q to quit\n") - if selection == "Q" or selection == "q": - print("Quitting...") - break - - -# finally, shut down the client -device_client.shutdown() diff --git a/samples/sync-samples/recurring_telemetry.py b/samples/sync-samples/recurring_telemetry.py deleted file mode 100644 index 50cee449c..000000000 --- a/samples/sync-samples/recurring_telemetry.py +++ /dev/null @@ -1,48 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import os -import uuid -import time -from azure.iot.device import IoTHubDeviceClient -from azure.iot.device import Message - - -def main(): - # The connection string for a device should never be stored in code. For the sake of simplicity we're using an environment variable here. - conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - # The client object is used to interact with your Azure IoT hub. - device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - - print("IoTHub Device Client Recurring Telemetry Sample") - print("Press Ctrl+C to exit") - try: - # Connect the client. 
- device_client.connect() - - # Send recurring telemetry - i = 0 - while True: - i += 1 - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.correlation_id = "correlation-1234" - msg.custom_properties["tornado-warning"] = "yes" - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - print("sending message #" + str(i)) - device_client.send_message(msg) - time.sleep(2) - except KeyboardInterrupt: - print("User initiated exit") - except Exception: - print("Unexpected exception!") - raise - finally: - device_client.shutdown() - - -if __name__ == "__main__": - main() diff --git a/samples/sync-samples/send_message.py b/samples/sync-samples/send_message.py deleted file mode 100644 index d124ac95b..000000000 --- a/samples/sync-samples/send_message.py +++ /dev/null @@ -1,72 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -import time -import uuid -from azure.iot.device import IoTHubDeviceClient, Message - -# The connection string for a device should never be stored in code. For the sake of simplicity we're using an environment variable here. -conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") -# The client object is used to interact with your Azure IoT hub. -device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - -# Connect the client. 
-device_client.connect() - -# send 2 messages with 2 system properties & 1 custom property with a 1 second pause between each message -for i in range(1, 3): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.correlation_id = "correlation-1234" - msg.custom_properties["tornado-warning"] = "yes" - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - device_client.send_message(msg) - time.sleep(1) - -# send 2 messages with only custom property with a 1 second pause between each message -for i in range(3, 5): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.custom_properties["tornado-warning"] = "yes" - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - device_client.send_message(msg) - time.sleep(1) - -# send 2 messages with only system properties with a 1 second pause between each message -for i in range(5, 7): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.correlation_id = "correlation-1234" - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - device_client.send_message(msg) - time.sleep(1) - -# send 2 messages with 1 system property and 1 custom property with a 1 second pause between each message -for i in range(7, 9): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.custom_properties["tornado-warning"] = "yes" - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - device_client.send_message(msg) - time.sleep(1) - -# send only string messages -for i in range(9, 11): - print("sending message #" + str(i)) - device_client.send_message("test payload message " + str(i)) - time.sleep(1) - - -# finally, shut down the client -device_client.shutdown() diff --git a/samples/sync-samples/send_message_to_output.py 
b/samples/sync-samples/send_message_to_output.py deleted file mode 100644 index 08b05347c..000000000 --- a/samples/sync-samples/send_message_to_output.py +++ /dev/null @@ -1,38 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import time -import uuid -from azure.iot.device import IoTHubModuleClient, Message - -# Inputs/Outputs are only supported in the context of Azure IoT Edge and module client -# The module client object acts as an Azure IoT Edge module and interacts with an Azure IoT Edge hub -module_client = IoTHubModuleClient.create_from_edge_environment() - -# Connect the client. -module_client.connect() - -# send 5 messages with a 1 second pause between each message -for i in range(1, 6): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.correlation_id = "correlation-1234" - msg.custom_properties["tornado-warning"] = "yes" - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - module_client.send_message_to_output(msg, "twister") - time.sleep(1) - -# send only string messages -for i in range(6, 11): - print("sending message #" + str(i)) - module_client.send_message_to_output("test payload message " + str(i), "tracking") - time.sleep(1) - - -# finally, shut down the client -module_client.shutdown() diff --git a/samples/sync-samples/send_message_via_module_x509.py b/samples/sync-samples/send_message_via_module_x509.py deleted file mode 100644 index 65d25c952..000000000 --- a/samples/sync-samples/send_message_via_module_x509.py +++ /dev/null @@ -1,54 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. 
-# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -import time -import uuid -from azure.iot.device import IoTHubModuleClient, Message, X509 - -hostname = os.getenv("HOSTNAME") - -# The device having a certain module that has been created on the portal -# using X509 CA signing or Self signing capabilities -# The \ should be the common name of the certificate - -device_id = os.getenv("DEVICE_ID") -module_id = os.getenv("MODULE_ID") - -x509 = X509( - cert_file=os.getenv("X509_CERT_FILE"), - key_file=os.getenv("X509_KEY_FILE"), - pass_phrase=os.getenv("X509_PASS_PHRASE"), -) - -module_client = IoTHubModuleClient.create_from_x509_certificate( - hostname=hostname, x509=x509, device_id=device_id, module_id=module_id -) - -module_client.connect() - - -# send 5 messages with a 1 second pause between each message -for i in range(1, 6): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.correlation_id = "correlation-1234" - msg.custom_properties["tornado-warning"] = "yes" - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - module_client.send_message(msg) - time.sleep(1) - -# send only string messages -for i in range(6, 11): - print("sending message #" + str(i)) - module_client.send_message("test payload message " + str(i)) - time.sleep(1) - - -# finally, shut down the client -module_client.shutdown() diff --git a/samples/sync-samples/send_message_via_proxy.py b/samples/sync-samples/send_message_via_proxy.py deleted file mode 100644 index ed5f2fe32..000000000 --- a/samples/sync-samples/send_message_via_proxy.py +++ /dev/null @@ -1,80 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -import time -import uuid -from azure.iot.device import IoTHubDeviceClient, Message, ProxyOptions -import logging - -logging.basicConfig(level=logging.DEBUG) - -# The connection string for a device should never be stored in code. For the sake of simplicity we're using an environment variable here. -conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - -# Create proxy options when trying to send via proxy -proxy_opts = ProxyOptions(proxy_type="HTTP", proxy_addr="127.0.0.1", proxy_port=8888) # localhost -# The client object is used to interact with your Azure IoT hub. -device_client = IoTHubDeviceClient.create_from_connection_string( - conn_str, websockets=True, proxy_options=proxy_opts -) - -# Connect the client. -device_client.connect() - -# send 2 messages with 2 system properties & 1 custom property with a 1 second pause between each message -for i in range(1, 3): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.correlation_id = "correlation-1234" - msg.custom_properties["tornado-warning"] = "yes" - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - device_client.send_message(msg) - time.sleep(1) - -# send 2 messages with only custom property with a 1 second pause between each message -for i in range(3, 5): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.custom_properties["tornado-warning"] = "yes" - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - device_client.send_message(msg) - time.sleep(1) - -# send 2 messages with only system properties with a 1 second pause between each message -for i in range(5, 7): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.correlation_id = 
"correlation-1234" - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - device_client.send_message(msg) - time.sleep(1) - -# send 2 messages with 1 system property and 1 custom property with a 1 second pause between each message -for i in range(7, 9): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.custom_properties["tornado-warning"] = "yes" - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - device_client.send_message(msg) - time.sleep(1) - -# send only string messages -for i in range(9, 11): - print("sending message #" + str(i)) - device_client.send_message("test payload message " + str(i)) - time.sleep(1) - - -# finally, shut down the client -device_client.shutdown() diff --git a/samples/sync-samples/send_message_x509.py b/samples/sync-samples/send_message_x509.py deleted file mode 100644 index 4115e1a09..000000000 --- a/samples/sync-samples/send_message_x509.py +++ /dev/null @@ -1,56 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import os -import time -import uuid -from azure.iot.device import IoTHubDeviceClient, Message, X509 - -# The connection string for a device should never be stored in code. 
-# For the sake of simplicity we are creating the X509 connection string -# containing Hostname and Device Id in the following format: -# "HostName=;DeviceId=;x509=true" - -hostname = os.getenv("HOSTNAME") - -# The device that has been created on the portal using X509 CA signing or Self signing capabilities -device_id = os.getenv("DEVICE_ID") - -x509 = X509( - cert_file=os.getenv("X509_CERT_FILE"), - key_file=os.getenv("X509_KEY_FILE"), - pass_phrase=os.getenv("X509_PASS_PHRASE"), -) - -# The client object is used to interact with your Azure IoT hub. -device_client = IoTHubDeviceClient.create_from_x509_certificate( - hostname=hostname, device_id=device_id, x509=x509 -) - -# Connect the client. -device_client.connect() - -# send 5 messages with a 1 second pause between each message -for i in range(1, 6): - print("sending message #" + str(i)) - msg = Message("test wind speed " + str(i)) - msg.message_id = uuid.uuid4() - msg.correlation_id = "correlation-1234" - msg.custom_properties["tornado-warning"] = "yes" - msg.content_encoding = "utf-8" - msg.content_type = "application/json" - device_client.send_message(msg) - time.sleep(1) - -# send only string messages -for i in range(6, 11): - print("sending message #" + str(i)) - device_client.send_message("test payload message " + str(i)) - time.sleep(1) - - -# finally, shut down the client -device_client.shutdown() diff --git a/samples/sync-samples/update_twin_reported_properties.py b/samples/sync-samples/update_twin_reported_properties.py deleted file mode 100644 index 262cf410b..000000000 --- a/samples/sync-samples/update_twin_reported_properties.py +++ /dev/null @@ -1,23 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import os -import random -from azure.iot.device import IoTHubDeviceClient - -conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") -device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - -# connect the client. -device_client.connect() - -# send new reported properties -reported_properties = {"temperature": random.randint(320, 800) / 10} -print("Setting reported temperature to {}".format(reported_properties["temperature"])) -device_client.patch_twin_reported_properties(reported_properties) - -# finally, shut down the client -device_client.shutdown() diff --git a/samples/sync-samples/use_custom_sastoken.py b/samples/sync-samples/use_custom_sastoken.py deleted file mode 100644 index fb7d04549..000000000 --- a/samples/sync-samples/use_custom_sastoken.py +++ /dev/null @@ -1,57 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -from azure.iot.device import IoTHubDeviceClient - - -# NOTE: This code needs to be completed in order to work. -# Fill out the get_new_sastoken() method to return a NEW custom sastoken from your solution. -# It must return a unique value each time it is called. -def get_new_sastoken(): - pass - - -# The connection string for a device should never be stored in code. For the sake of simplicity we're using an environment variable here. -sastoken = get_new_sastoken() -# The client object is used to interact with your Azure IoT hub. -device_client = IoTHubDeviceClient.create_from_sastoken(sastoken) - - -# connect the client. 
-device_client.connect() - - -# define behavior for receiving a message -def message_handler(message): - print("the data in the message received was ") - print(message.data) - print("custom properties are") - print(message.custom_properties) - - -# define behavior for updating sastoken -def sastoken_update_handler(): - print("Updating SAS Token...") - sastoken = get_new_sastoken() - device_client.update_sastoken(sastoken) - print("SAS Token updated") - - -# set the message handler on the client -device_client.on_message_received = message_handler -device_client.on_new_sastoken_required = sastoken_update_handler - - -# Wait for user to indicate they are done listening for messages -while True: - selection = input("Press Q to quit\n") - if selection == "Q" or selection == "q": - print("Quitting...") - break - - -# finally, shut down the client -device_client.shutdown() diff --git a/scripts/configure-virtual-environments.sh b/scripts/configure-virtual-environments.sh index 54c510269..c484563c9 100755 --- a/scripts/configure-virtual-environments.sh +++ b/scripts/configure-virtual-environments.sh @@ -6,7 +6,7 @@ script_dir=$(cd "$(dirname "$0")" && pwd) -export RUNTIMES_TO_INSTALL="3.6.6 3.7.1 3.8.10 3.9.9 3.10.2" +export RUNTIMES_TO_INSTALL="3.7.1 3.8.10 3.9.9 3.10.2" echo "This script will do the following:" echo "1. Use apt to install pre-requisites for pyenv" diff --git a/scripts/edge_setup/deploy/_deployment-helpers.sh b/scripts/edge_setup/deploy/_deployment-helpers.sh new file mode 100644 index 000000000..b3a2e2be1 --- /dev/null +++ b/scripts/edge_setup/deploy/_deployment-helpers.sh @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for full license information. 
+ +export PATH_AGENT_PROPS=".modulesContent[\"\$edgeAgent\"][\"properties.desired\"]" +export PATH_SYSTEM_MODULES="${PATH_AGENT_PROPS}.systemModules" +export PATH_MODULES="${PATH_AGENT_PROPS}.modules" +export PATH_REGISTRY_CREDENTIALS="${PATH_AGENT_PROPS}.runtime.settings.registryCredentials" +export PATH_HUB_PROPS=".modulesContent[\"\$edgeHub\"][\"properties.desired\"]" +export PATH_ROUTES="${PATH_HUB_PROPS}.routes" diff --git a/scripts/edge_setup/deploy/add-module-to-deployment.sh b/scripts/edge_setup/deploy/add-module-to-deployment.sh new file mode 100755 index 000000000..f98ac9e3b --- /dev/null +++ b/scripts/edge_setup/deploy/add-module-to-deployment.sh @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for full license information. +set -o pipefail + +script_dir=$(cd "$(dirname "$0")" && pwd) +source ${script_dir}/_deployment-helpers.sh + +HUB_NAME=$1 +DEVICE_ID=$2 +MODULE_NAME=$3 +MODULE_IMAGE_NAME=$4 +CREATE_OPTIONS=$5 + +if [ "${CREATE_OPTIONS}" == "" ]; then + CREATE_OPTIONS="\"{}\"" +fi + + +if [ "${HUB_NAME}" == "" ] || [ "${DEVICE_ID}" == "" ] || [ "${MODULE_NAME}" == "" ] || [ "${MODULE_IMAGE_NAME}" == "" ]; then + echo Usage: $0 hubName deviceId moduleName moduleImageName [createOptionsJsonString] + echo hubName is without '.azure-devices.net' suffix + exit 1 +fi + +echo Creating manifest json with module ${MODULE_NAME} +TEMPFILE=$(mktemp) + +# +# JSON with required module JSON +# +read -d '' EMPTY_MODULE_JSON << EOF +{ + "version": "1.0", + "type": "docker", + "status": "running", + "restartPolicy": "always", + "settings": { + "image": "TODO", + "createOptions": "{}" + } +} +EOF + +BASE=$(az iot edge export-modules --device-id ${DEVICE_ID} --hub-name ${HUB_NAME} --query content) +[ $? -eq 0 ] || { echo "az iot edge export-modules failed"; exit 1; } + +echo ${BASE} | jq . 
\ + | jq "${PATH_MODULES}.${MODULE_NAME} = ${EMPTY_MODULE_JSON}" \ + | jq "${PATH_MODULES}.${MODULE_NAME}.settings.image = \"${MODULE_IMAGE_NAME}\"" \ + | jq "${PATH_MODULES}.${MODULE_NAME}.settings.createOptions = ${CREATE_OPTIONS}" \ + > ${TEMPFILE} +[ $? -eq 0 ] || { "jq failed"; exit 1; } + +echo Applying manifest json with module ${MODULE_NAME} +az iot edge set-modules --device-id ${DEVICE_ID} --hub-name ${HUB_NAME} --content ${TEMPFILE} > /dev/null +[ $? -eq 0 ] || { echo "az iot edge set-modules failed"; exit 1; } + +rm ${TEMPFILE} || true diff --git a/scripts/edge_setup/deploy/add-registry-to-deployment.sh b/scripts/edge_setup/deploy/add-registry-to-deployment.sh new file mode 100755 index 000000000..9d864c81b --- /dev/null +++ b/scripts/edge_setup/deploy/add-registry-to-deployment.sh @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for full license information. + +script_dir=$(cd "$(dirname "$0")" && pwd) +source ${script_dir}/_deployment-helpers.sh + +HUB_NAME=$1 +DEVICE_ID=$2 + +if [ "${HUB_NAME}" == "" ] || [ "${DEVICE_ID}" == "" ]; then + echo Usage: $0 hubName deviceId + echo hubName is without '.azure-devices.net' suffix + exit 1 +fi + +if [ "${IOTHUB_E2E_REPO_USER}" == "" ] || [ "${IOTHUB_E2E_REPO_ADDRESS}" == "" ] || [ "${IOTHUB_E2E_REPO_PASSWORD}" == "" ]; then + echo "No private repostiry specified" + exit 0 +fi + +echo Creating manifest json with private registry +TEMPFILE=$(mktemp) + +BASE=$(az iot edge export-modules --device-id ${DEVICE_ID} --hub-name ${HUB_NAME} --query content) +[ $? -eq 0 ] || { echo "az iot edge export-modules failed"; exit 1; } + +# +# JSON for container retistry credentials +# +read -d '' REGISTRY_BLOCK << EOF +{ + ${IOTHUB_E2E_REPO_USER}: { + address: \"${IOTHUB_E2E_REPO_ADDRESS}\", + username: \"${IOTHUB_E2E_REPO_USER}\", + password: \"${IOTHUB_E2E_REPO_PASSWORD}\" + } +} +EOF + +echo ${BASE} | jq . 
- \ + | jq "${PATH_REGISTRY_CREDENTIALS} = ${REGISTRY_BLOCK}" \ + > ${TEMPFILE} +[ $? -eq 0 ] || { "jq failed"; exit 1; } + +echo Applying manifest json with private registry +az iot edge set-modules --device-id ${DEVICE_ID} --hub-name ${HUB_NAME} --content ${TEMPFILE} &> /dev/null +[ $? -eq 0 ] || { echo "az iot edge set-modules failed"; exit 1; } + +rm ${TEMPFILE} || true diff --git a/scripts/edge_setup/deploy/add-routing-rules-to-deployment.sh b/scripts/edge_setup/deploy/add-routing-rules-to-deployment.sh new file mode 100755 index 000000000..020fefb2a --- /dev/null +++ b/scripts/edge_setup/deploy/add-routing-rules-to-deployment.sh @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for full license information. +set -o pipefail + +script_dir=$(cd "$(dirname "$0")" && pwd) +source ${script_dir}/_deployment-helpers.sh + +HUB_NAME=$1 +DEVICE_ID=$2 +ROUTING_RULES=$3 + +if [ "${HUB_NAME}" == "" ] || [ "${DEVICE_ID}" == "" ] || [ "${ROUTING_RULES}" == "" ]; then + echo Usage: $0 hubName deviceId routingRules + echo hubName is without '.azure-devices.net' suffix + exit 1 +fi + +echo Creating manifest json with routing rules +TEMPFILE=$(mktemp) + +BASE=$(az iot edge export-modules --device-id ${DEVICE_ID} --hub-name ${HUB_NAME} --query content) +[ $? -eq 0 ] || { echo "az iot edge export-modules failed"; exit 1; } + +echo ${BASE} | jq . \ + | jq "${PATH_ROUTES} = ${ROUTING_RULES}" \ + | jq "${PATH_ROUTES} |= fromjson" \ + > ${TEMPFILE} +[ $? -eq 0 ] || { "jq failed"; exit 1; } + +echo Applying manifest json with routing rules +az iot edge set-modules --device-id ${DEVICE_ID} --hub-name ${HUB_NAME} --content ${TEMPFILE} > /dev/null +[ $? 
-eq 0 ] || { echo "az iot edge set-modules failed"; exit 1; } + +rm ${TEMPFILE} || true diff --git a/scripts/edge_setup/deploy/create-edge-device.sh b/scripts/edge_setup/deploy/create-edge-device.sh new file mode 100755 index 000000000..607f3376b --- /dev/null +++ b/scripts/edge_setup/deploy/create-edge-device.sh @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for full license information. + +script_dir=$(cd "$(dirname "$0")" && pwd) + +HUB_NAME=$1 +DEVICE_ID=$2 + +if [ "${HUB_NAME}" == "" ] || [ "${DEVICE_ID}" == "" ]; then + echo Usage: $0 hubName deviceId + echo hubName is without '.azure-devices.net' suffix + exit 1 +fi + +echo Creating device ${DEVICE_ID} on hub ${HUB_NAME} +TEMPFILE=$(mktemp) +az iot hub device-identity create -n ${HUB_NAME} --device-id ${DEVICE_ID} --edge-enabled &> ${TEMPFILE} +if [ $? -ne 0 ]; then + echo "az iot hub device-identity create failed" + cat ${TEMPFILE} + rm ${TEMPFILE} + exit 1; +fi + +echo Getting connection string for ${DEVICE_ID} on ${HUB_NAME} +CS=$(az iot hub device-identity connection-string show -d ${DEVICE_ID} -n ${HUB_NAME} --output tsv --query "connectionString") +[ $? -eq 0 ] || { echo "az iot hub device-identity connection-string show failed"; exit 1; } + +echo Setting IoTHub configuration +sudo -E iotedge config mp --force --connection-string ${CS} +[ $? -eq 0 ] || { echo "iotedge config mp failed"; exit 1; } + +sudo iotedge config apply +[ $? -eq 0 ] || { echo "iotedge config apply failed"; exit 1; } + diff --git a/scripts/edge_setup/deploy/deploy-edge-modules.sh b/scripts/edge_setup/deploy/deploy-edge-modules.sh new file mode 100755 index 000000000..e31c09aa7 --- /dev/null +++ b/scripts/edge_setup/deploy/deploy-edge-modules.sh @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for full license information. 
+ +script_dir=$(cd "$(dirname "$0")" && pwd) + +HUB_NAME=$1 +DEVICE_ID=$2 +TEST_IMAGE_NAME=$3 +ECHO_IMAGE_NAME=$4 + +if [ "${HUB_NAME}" == "" ] || [ "${DEVICE_ID}" == "" ] || [ "${TEST_IMAGE_NAME}" == "" ] || [ "${ECHO_IMAGE_NAME}" == "" ]; then + echo Usage: $0 hubName deviceId testImageName echoImageName + echo hubName is without '.azure-devices.net' suffix + exit 1 +fi + +# +# JSON for test module createOptions +# +read -d '' TEST_MOD_CREATE_OPTIONS << EOF +{ + "HostConfig": { + "Binds": [ + "/home/bertk/projects/v3:/sdk" + ] + }, + "Entrypoint": [ + "python3", + "-uc", + "import time; print('waiting'); time.sleep(3600);" + ] +} +EOF +TEST_MOD_CREATE_OPTIONS=$(echo ${TEST_MOD_CREATE_OPTIONS} | jq ". |= tojson") +[ $? -eq 0 ] || { "jq failed"; exit 1; } + +# +# JSON for echo module createOptions +# +read -d '' ECHO_MOD_CREATE_OPTIONS << EOF +{ + "Entrypoint": [ + "python3", + "-uc", + "import time; print('waiting'); time.sleep(3600);" + ] +} +EOF +ECHO_MOD_CREATE_OPTIONS=$(echo ${ECHO_MOD_CREATE_OPTIONS} | jq ". |= tojson") +[ $? -eq 0 ] || { "jq failed"; exit 1; } + +# +# JSON for routing rules +# +ROUTING_RULES=$( +jq ". |= tojson" < ${TEMPFILE} +[ $? -eq 0 ] || { "jq failed"; exit 1; } + +echo Applying base manifest json +az iot edge set-modules --device-id ${DEVICE_ID} --hub-name ${HUB_NAME} --content ${TEMPFILE} > /dev/null +[ $? -eq 0 ] || { echo "az iot edge set-modules failed"; exit 1; } + +rm ${TEMPFILE} || true diff --git a/scripts/edge_setup/docker-build/build-echomod-container.sh b/scripts/edge_setup/docker-build/build-echomod-container.sh new file mode 100755 index 000000000..b2e3a49d8 --- /dev/null +++ b/scripts/edge_setup/docker-build/build-echomod-container.sh @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for full license information. 
+script_dir=$(cd "$(dirname "$0")" && pwd) + +IMAGE_NAME=$1 + +if [ "${IMAGE_NAME}" == "" ]; then + echo Usage: $0 imageName + echo eg: $0 localhost:5000/echomod:latest + exit 1 +fi + +docker pull ${IMAGE_NAME} + +if [ $? -eq 0 ]; then + echo "${IMAGE_NAME} already exists. Skipping build" +else + cd ${script_dir}/echoMod + + docker build -t ${IMAGE_NAME} . + [ $? -eq 0 ] || { echo "docker build failed"; exit 1; } + + docker push ${IMAGE_NAME} + [ $? -eq 0 ] || { echo "docker push failed"; exit 1; } +fi + diff --git a/scripts/edge_setup/docker-build/build-test-containers.sh b/scripts/edge_setup/docker-build/build-test-containers.sh new file mode 100755 index 000000000..402e7bf18 --- /dev/null +++ b/scripts/edge_setup/docker-build/build-test-containers.sh @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for full license information. +script_dir=$(cd "$(dirname "$0")" && pwd) +root_dir=$(cd "${script_dir}/../../.." && pwd) + +DOCKERFILE_NAME=$1 +IMAGE_NAME=$2 + +if [ "${DOCKERFILE_NAME}" == "" ] || [ "${IMAGE_NAME}" == "" ]; then + echo Usage: $0 dockerfileName imageName + echo eg. $0 Dockerfile.py310 localhost:5000/python-e2e-py310:latest + exit 1 +fi + +cd ${script_dir}/dockerfiles + +docker build -t ${IMAGE_NAME} ${root_dir} -f ${DOCKERFILE_NAME} +[ $? -eq 0 ] || { echo "docker build failed"; exit 1; } + +docker push ${IMAGE_NAME} +[ $? -eq 0 ] || { echo "docker push failed"; exit 1; } + diff --git a/scripts/edge_setup/docker-build/dockerfiles/Dockerfile.py310 b/scripts/edge_setup/docker-build/dockerfiles/Dockerfile.py310 new file mode 100644 index 000000000..e271eed3b --- /dev/null +++ b/scripts/edge_setup/docker-build/dockerfiles/Dockerfile.py310 @@ -0,0 +1,13 @@ +FROM mcr.microsoft.com/mirror/docker/library/python:3.10-slim-buster + +RUN apt update \ + && apt install -y \ + iptables \ + && apt clean + +WORKDIR /sdk +COPY requirements_test.txt . 
+RUN pip install -r requirements_test.txt +COPY . . +RUN python ./scripts/env_setup.py --no_dev + diff --git a/scripts/edge_setup/docker-build/echoMod/.dockerignore b/scripts/edge_setup/docker-build/echoMod/.dockerignore new file mode 100644 index 000000000..15813be9f --- /dev/null +++ b/scripts/edge_setup/docker-build/echoMod/.dockerignore @@ -0,0 +1,2 @@ +package-lock.json +node_modules/ diff --git a/scripts/edge_setup/docker-build/echoMod/.gitignore b/scripts/edge_setup/docker-build/echoMod/.gitignore new file mode 100644 index 000000000..15813be9f --- /dev/null +++ b/scripts/edge_setup/docker-build/echoMod/.gitignore @@ -0,0 +1,2 @@ +package-lock.json +node_modules/ diff --git a/scripts/edge_setup/docker-build/echoMod/Dockerfile b/scripts/edge_setup/docker-build/echoMod/Dockerfile new file mode 100644 index 000000000..8bbbf309d --- /dev/null +++ b/scripts/edge_setup/docker-build/echoMod/Dockerfile @@ -0,0 +1,6 @@ +FROM mcr.microsoft.com/mirror/docker/library/node:16 +env DEBUG=rhea*,azure* +WORKDIR /sdk +COPY . /sdk +RUN npm install +ENTRYPOINT ["/usr/local/bin/node", "/sdk/app.js"] diff --git a/scripts/edge_setup/docker-build/echoMod/app.js b/scripts/edge_setup/docker-build/echoMod/app.js new file mode 100644 index 000000000..7a2e133ae --- /dev/null +++ b/scripts/edge_setup/docker-build/echoMod/app.js @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. 
+"use strict"; + +const Protocol = require("azure-iot-device-mqtt").Mqtt; +const ModuleClient = require("azure-iot-device").ModuleClient; +const Message = require("azure-iot-device").Message; + +ModuleClient.fromEnvironment(Protocol, (err, client) => { + if (err) { + console.error(`Could not create client: {err}`); + process.exit(-1); + } else { + console.log("got client"); + + client.on("error", (err) => { + console.error(err.message); + }); + + client.open((err) => { + if (err) { + console.error(`Could not connect: {err}`); + process.exit(-1); + } else { + console.log("Client connected"); + + // Act on input messages to the module. + client.on("inputMessage", (inputName, msg) => { + if (inputName === "input1") { + client.sendOutputEvent("output2", msg, (err) => { + if (err) { + console.log(`sendOutputEvent failed {err}`); + } + }); + } else { + console.log(`unexpected input: {inputName}`); + } + }); + } + }); + } +}); diff --git a/scripts/edge_setup/docker-build/echoMod/package.json b/scripts/edge_setup/docker-build/echoMod/package.json new file mode 100644 index 000000000..0ad55ff02 --- /dev/null +++ b/scripts/edge_setup/docker-build/echoMod/package.json @@ -0,0 +1,12 @@ +{ + "dependencies": { + "azure-iot-device": "^1.18.2", + "azure-iot-device-mqtt": "^1.16.2" + }, + "scripts": { + "prettier": "prettier -w *.js package.json" + }, + "devDependencies": { + "prettier": "^2.8.8" + } +} diff --git a/scripts/edge_setup/install/install-azure-cli.sh b/scripts/edge_setup/install/install-azure-cli.sh new file mode 100755 index 000000000..98246b2b0 --- /dev/null +++ b/scripts/edge_setup/install/install-azure-cli.sh @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for full license information. + +script_dir=$(cd "$(dirname "$0")" && pwd) +set -o pipefail + +echo "Checking for azure CLI" +which az +if [ $? 
-eq 0 ]; then + echo "Azure CLI installed" +else + echo "Installing Azure CLI" + + curl -L https://aka.ms/InstallAzureCli | bash + [ $? -eq 0 ] || { echo "install-microsoft-apt-repo failed"; exit 1; } +fi + +az --version +[ $? -eq 0 ] || { echo "az --version failed"; exit 1; } + +az extension add --name azure-iot +[ $? -eq 0 ] || { echo "az extension add failed"; exit 1; } + diff --git a/scripts/edge_setup/install/install-iotedge.sh b/scripts/edge_setup/install/install-iotedge.sh new file mode 100755 index 000000000..7d0c734e8 --- /dev/null +++ b/scripts/edge_setup/install/install-iotedge.sh @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for full license information. + +script_dir=$(cd "$(dirname "$0")" && pwd) + +$script_dir/install-microsoft-apt-repo.sh +[ $? -eq 0 ] || { echo "install-microsoft-apt-repo failed"; exit 1; } + +# install iotedge +sudo apt-get install -y aziot-edge +[ $? -eq 0 ] || { echo "apt-get install aziot-edge failed"; exit 1; } + + diff --git a/scripts/edge_setup/install/install-microsoft-apt-repo.sh b/scripts/edge_setup/install/install-microsoft-apt-repo.sh new file mode 100755 index 000000000..06c36901e --- /dev/null +++ b/scripts/edge_setup/install/install-microsoft-apt-repo.sh @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for full license information. + +script_dir=$(cd "$(dirname "$0")" && pwd) +source /etc/os-release + + +echo "Checking for Microsoft APT repo registration" +if [ -f /etc/apt/sources.list.d/microsoft-prod.list ]; then + echo "Microsoft APT repo already registered. Done." 
+ exit 0 +fi + +# Download the Microsoft repository GPG keys +case $ID in + linuxmint) + if [[ $VERSION_ID == "19.3" ]]; + then + os_platform="ubuntu/18.04/multiarch" + fi + ;; + ubuntu) + if [[ $VERSION_ID == "18.04" ]]; + then + os_platform="$ID/$VERSION_ID/multiarch" + else + os_platform="$ID/$VERSION_ID" + fi + ;; + + raspbian) + if [ "$VERSION_CODENAME" == "bullseye" ] || [ "$VERSION_ID" == "11" ]; + then + os_platform="$ID_LIKE/11" + else + os_platform="$ID_LIKE/stretch/multiarch" + fi + ;; +esac + +if [ "${os_platform}" == "" ]; then + echo "ERROR: This script only works on Ubuntu and Raspbian distros" + exit 1 +fi + +curl https://packages.microsoft.com/config/${os_platform}/prod.list > ./microsoft-prod.list +[ $? -eq 0 ] || { echo "curl failed"; exit 1; } + +# Register the Microsoft repository GPG keys +sudo cp ./microsoft-prod.list /etc/apt/sources.list.d/ +[ $? -eq 0 ] || { echo "sudo cp microsoft-prod.list failed"; exit 1; } + +rm microsoft-prod.list + +curl https://packages.microsoft.com/keys/microsoft.asc | gpg --dearmor > microsoft.gpg +[ $? -eq 0 ] || { echo "curl microsoft.asc failed"; exit 1; } + +sudo cp ./microsoft.gpg /etc/apt/trusted.gpg.d/ +[ $? -eq 0 ] || { echo "cp microsoft.gpg failed"; exit 1; } + +rm microsoft.gpg + +# Update the list of products +sudo apt-get update +[ $? -eq 0 ] || { echo "apt update failed"; exit 1; } + +echo "Microsoft APT repo successfully registered" + diff --git a/scripts/edge_setup/install/install-moby.sh b/scripts/edge_setup/install/install-moby.sh new file mode 100755 index 000000000..2c1f1126f --- /dev/null +++ b/scripts/edge_setup/install/install-moby.sh @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for full license information. + +script_dir=$(cd "$(dirname "$0")" && pwd) + +echo "Checking for moby install" +unset moby_cli_installed +dpkg -s moby-cli | grep -q "install ok installed" +if [ $? 
-eq 0 ]; then moby_cli_installed=true; fi + +unset moby_engine_installed +dpkg -s moby-engine | grep -q "install ok installed" +if [ $? -eq 0 ]; then moby_engine_installed=true; fi + +need_moby=true +if [ $moby_engine_installed ] && [ $moby_cli_installed ]; then + echo "moby is already installed" + unset need_moby +fi + +if [ "$need_moby" ]; then + echo "checking for docker install" + which docker > /dev/null + if [ $? -eq 0 ]; then + echo "docker is already installed" + unset need_moby + fi +fi + +if [ "$need_moby" ]; then + $script_dir/install-microsoft-apt-repo.sh + [ $? -eq 0 ] || { echo "install-microsoft-apt-repo failed"; exit 1; } + + echo "installing moby" + sudo apt-get install -y moby-engine + [ $? -eq 0 ] || { echo "apt-get failed"; exit 1; } + + sudo apt-get install -y moby-cli + [ $? -eq 0 ] || { echo "apt-get failed"; exit 1; } + + # wait for the docker engine to start. + sleep 10s +fi + +# add a group called 'docker' so we can add ourselves to it. Sometimes it gets automatically created, sometimes not. +echo "creating docker group" +sudo groupadd docker +# allowed to fail if the group already exists + +# add ourselves to the docker group. The user will have to restart the bash prompt to run docker, so we'll just +# sudo all of our docker calls in this script. +echo "adding $USER to docker group" +sudo usermod -aG docker $USER +[ $? -eq 0 ] || { echo "usermod failed"; exit 1; } + +echo "Moby/Docker successfully installed" + diff --git a/scripts/edge_setup/install/install-prereqs.sh b/scripts/edge_setup/install/install-prereqs.sh new file mode 100755 index 000000000..590311b7f --- /dev/null +++ b/scripts/edge_setup/install/install-prereqs.sh @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for full license information. + +script_dir=$(cd "$(dirname "$0")" && pwd) + +$script_dir/install-azure-cli.sh +[ $? 
-eq 0 ] || { echo "install-azure-cli.sh failed"; exit 1; } + +$script_dir/install-moby.sh +[ $? -eq 0 ] || { echo "install-moby.sh failed"; exit 1; } + +$script_dir/install-iotedge.sh +[ $? -eq 0 ] || { echo "install-moby.sh failed"; exit 1; } diff --git a/scripts/edge_setup/wait-for-container.sh b/scripts/edge_setup/wait-for-container.sh new file mode 100755 index 000000000..43bf33453 --- /dev/null +++ b/scripts/edge_setup/wait-for-container.sh @@ -0,0 +1,11 @@ +# wait for a docker container to be running +CONTAINER_NAME=$1 +while true; do + state=$(docker inspect -f {{.State.Running}} ${CONTAINER_NAME}) + if [ $? -eq 0 ] && [ "${state}" == "true" ]; then + echo ${CONTAINER_NAME} is running + exit 0 + else + sleep 5 + fi +done; diff --git a/scripts/infra_tools/certGen.sh b/scripts/infra_tools/certGen.sh index 6f60a9b15..004d681f6 100644 --- a/scripts/infra_tools/certGen.sh +++ b/scripts/infra_tools/certGen.sh @@ -34,7 +34,6 @@ ca_chain_prefix="azure-iot-test-only.chain.ca" intermediate_ca_dir="." openssl_root_config_file="./openssl_root_ca.cnf" openssl_intermediate_config_file="./openssl_device_intermediate_ca.cnf" -intermediate_ca_password="1234" root_ca_prefix="azure-iot-test-only.root.ca" intermediate_ca_prefix="azure-iot-test-only.intermediate" @@ -123,6 +122,7 @@ function generate_intermediate_ca() fi root_ca_password="${1}" + intermediate_ca_password="${1}" local common_name="Azure IoT Hub Intermediate Cert Test Only" diff --git a/sdklab/regressions/regression_pr_1023_infinite_get_twin.py b/sdklab/regressions/regression_pr_1023_infinite_get_twin.py deleted file mode 100644 index da3bc31c9..000000000 --- a/sdklab/regressions/regression_pr_1023_infinite_get_twin.py +++ /dev/null @@ -1,79 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import asyncio -import logging -from dev_utils import test_env -from azure.iot.device.aio import IoTHubDeviceClient - -logging.basicConfig(level=logging.WARNING) - -""" -This code checks to make sure the bug fixed by GitHub PR #1023 was fixed. - -Order of events for the bug repro: - -1. Customer code creates a client object and registers for twin change events. -2. When reconnecting, the device client code would issue a twin "GET" packet to the service. - This is done to ensure the client has a current version of the twin, which may have been - updated while the client was disconnected. -3. If `shutdown` is called _immediately_ after `connect`, the "GET" operation fails. This - is expected because there was no response. -4. The EnsureDesiredPropertyStage code responds to the GET operation failure by submitting a - new GET operation. -5. Because the client is shutting down, this second GET operation also fails. Go to step 4. - -Final result (before this fix): - -Multiple GET calls and an access violation. -""" - - -async def main(): - # Create instance of the device client using the connection string - device_client = IoTHubDeviceClient.create_from_connection_string( - test_env.DEVICE_CONNECTION_STRING - ) - - # Connect the device client. - await device_client.connect() - - async def on_patch(p): - print("Got patch") - - # Even though we're not expecting a patch, registering for the patch is an important - # precondition for this particular bug. 
- device_client.on_twin_desired_properties_patch_received = on_patch - - # Send a single message - print("Sending message...") - await device_client.send_message("This is a message that is being sent") - print("Message successfully sent!") - - print("Getting twin...") - await device_client.get_twin() - print("got twin...") - - print("Disconnecting") - await device_client.disconnect() - print("Disconnected") - - print("Connecting") - await device_client.connect() - print("Connected") - - # Finally, shut down the client - - # If this is done _immediately_ after the `connect` call, this used to trigger a race condition - # which would cause a stack overflow and core dump. Using `disconnect` instead of `shutdown` - # or putting a sleep before this would not repro the same bug. - - print("Shutting down") - await device_client.shutdown() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/setup.py b/setup.py index 92f0bb326..f17eeeae0 100644 --- a/setup.py +++ b/setup.py @@ -53,23 +53,23 @@ description="Microsoft Azure IoT Device Library", license="MIT License", license_files=("LICENSE",), - url="https://github.com/Azure/azure-iot-sdk-python/", + url="https://github.com/Azure/azure-iot-sdk-python/tree/v3", author="Microsoft Corporation", author_email="opensource@microsoft.com", long_description=_long_description, long_description_content_type="text/markdown", classifiers=[ - "Development Status :: 5 - Production/Stable", + "Development Status :: 4 - Beta", "Intended Audience :: Developers", "Topic :: Software Development :: Libraries :: Python Modules", "License :: OSI Approved :: MIT License", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", ], 
install_requires=[ # Define sub-dependencies due to pip dependency resolution bug @@ -79,14 +79,15 @@ # Security issue below 1.26.5 "urllib3>=1.26.5,<1.27", # Actual project dependencies - "deprecation>=2.1.0,<3.0.0", "paho-mqtt>=1.6.1,<2.0.0", - "requests>=2.20.0,<3.0.0", "requests-unixsocket>=0.1.5,<1.0.0", - "janus", + "typing-extensions>=4.4.0,<5.0", "PySocks", + # This dependency is needed by some modules, but none that are actually used + # in current IoTHubSession design. This can be removed once we settle on a direction. + "aiohttp", ], - python_requires=">=3.6, <4", + python_requires=">=3.7, <4", packages=find_namespace_packages(where="azure-iot-device"), package_dir={"": "azure-iot-device"}, zip_safe=False, diff --git a/tests/unit/common/__init__.py b/tests/e2e/__init__.py similarity index 100% rename from tests/unit/common/__init__.py rename to tests/e2e/__init__.py diff --git a/tests/e2e/iothub_e2e/aio/conftest.py b/tests/e2e/iothub_e2e/aio/conftest.py index da26080c5..47fc91d7d 100644 --- a/tests/e2e/iothub_e2e/aio/conftest.py +++ b/tests/e2e/iothub_e2e/aio/conftest.py @@ -6,65 +6,22 @@ from dev_utils import test_env, ServiceHelper import logging import datetime -import json -import retry_async -from utils import create_client_object -from azure.iot.device.iothub.aio import IoTHubDeviceClient, IoTHubModuleClient +from utils import create_session logger = logging.getLogger(__name__) logger.setLevel(level=logging.INFO) -@pytest.hookimpl(hookwrapper=True) -def pytest_pyfunc_call(pyfuncitem): - """ - pytest hook that gets called for running an individual test. We use this to store - retry statistics for this test in the `pyfuncitem` for the test. - """ - - # Reset tests before running the test - retry_async.reset_retry_stats() - - try: - # Run the test. We can do this because hookwrapper=True - yield - finally: - # If we actually collected any stats, store them. 
- if retry_async.retry_stats: - pyfuncitem.retry_stats = retry_async.retry_stats - - -@pytest.hookimpl(trylast=True) -def pytest_sessionfinish(session, exitstatus): - """ - pytest hook that gets called at the end of a test session. We use this to - log stress results to stdout. - """ - - # Loop through all of our tests and print contents of `retry_stats` if it exists. - printed_header = False - for item in session.items: - retry_stats = getattr(item, "retry_stats", None) - if retry_stats: - if not printed_header: - print( - "================================ retry summary =================================" - ) - printed_header = True - print("Retry stats for {}".format(item.name)) - print(json.dumps(retry_stats, indent=2)) - print("-----------------------------------") - - @pytest.fixture(scope="session") def event_loop(): - loop = asyncio.get_event_loop() + policy = asyncio.get_event_loop_policy() + loop = policy.new_event_loop() yield loop loop.close() @pytest.fixture(scope="function") -async def brand_new_client(device_identity, client_kwargs, service_helper, device_id, module_id): +async def session(device_identity, client_kwargs, service_helper, device_id, module_id): service_helper.set_identity(device_id, module_id) # Keep this here. It is useful to see this info inside the inside devops pipeline test failures. @@ -74,31 +31,14 @@ async def brand_new_client(device_identity, client_kwargs, service_helper, devic ) ) - client = create_client_object( - device_identity, client_kwargs, IoTHubDeviceClient, IoTHubModuleClient - ) + client = create_session(device_identity, client_kwargs) yield client logger.info("---------------------------------------") - logger.info("test is complete. Shutting down client") + logger.info("test is complete.") logger.info("---------------------------------------") - await client.shutdown() - - logger.info("-------------------------------------------") - logger.info("test is complete. 
client shutdown complete") - logger.info("-------------------------------------------") - - -@pytest.fixture(scope="function") -async def client(brand_new_client): - client = brand_new_client - - await client.connect() - - yield client - @pytest.fixture(scope="session") async def service_helper(event_loop, executor): diff --git a/tests/e2e/iothub_e2e/aio/retry_async.py b/tests/e2e/iothub_e2e/aio/retry_async.py deleted file mode 100644 index 1951df1c2..000000000 --- a/tests/e2e/iothub_e2e/aio/retry_async.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -import asyncio -import logging -import random -import threading -import time -from azure.iot.device.exceptions import ( - ConnectionFailedError, - ConnectionDroppedError, - OperationCancelled, - NoConnectionError, -) - -logger = logging.getLogger(__name__) -logger.setLevel(level=logging.INFO) - -# -------------------------------------- -# Parameters for our back-off and jitter -# -------------------------------------- - -# Retry immediately after failure, or wait until after first delay? -IMMEDIATE_FIRST_RETRY = True - -# Seconds to sleep for first sleep period. The exponential back-off will use -# 2x this number for the second sleep period, then 4x this number for the third -# period, then 8x and so on. -INITIAL_DELAY = 5 - -# Largest number of seconds to sleep between retries (before applying jitter) -MAXIMUM_DELAY = 60 - -# Number of seconds before an operation is considered "failed". This period starts before -# the first attempt and includes the elapsed time waiting between any failed attempts. -FAILURE_TIMEOUT = 5 * 60 - -# Jitter-up factor. The time, after jitter is applied, can be up this percentage larger than the -# pre-jittered time. -JITTER_UP_FACTOR = 0.25 - -# Jitter-down factor. 
The time, after jitter is applied, can be up this percentage smaller than the -# pre-jittered time. -JITTER_DOWN_FACTOR = 0.5 - -# Counter to keep track of running calls. We use this to distinguish between calls in logs. -running_call_index = 0 -running_call_index_lock = threading.Lock() - -# Retry stats. This is a dictionary of arbitrary values that we print to stdout at the of a -# test run. -retry_stats = {} -retry_stats_lock = threading.Lock() - - -def apply_jitter(base): - """ - Apply a jitter that can be `JITTER_DOWN_FACTOR` percent smaller than the base up to - `JITTER_UP_FACTOR` larger than the base. - """ - min_value = base * (1 - JITTER_DOWN_FACTOR) - max_value = base * (1 + JITTER_UP_FACTOR) - return random.uniform(min_value, max_value) - - -def increment_retry_stat_count(key): - """ - Increment a counter in the retry_stats dictionary - """ - - global retry_stats - with retry_stats_lock: - retry_stats[key] = retry_stats.get(key, 0) + 1 - - -def reset_retry_stats(): - """ - reset retry stats between tests - """ - global retry_stats - retry_stats = {} - - -def get_type_name(obj): - """ - Given an object, return the name of the type of that object. If `str(type(obj))` returns - `""`, this function returns `"threading.Thread"`. - """ - try: - return str(type(obj)).split("'")[1] - except Exception: - return str(type(obj)) - - -async def retry_exponential_backoff_with_jitter(client, func, *args, **kwargs): - """ - wrapper function to call a function with retry using exponential back-off with jitter. - """ - global running_call_index, running_call_index_lock - - increment_retry_stat_count("retry_operation_total_count") - increment_retry_stat_count("retry_operation{}".format(func.__name__)) - - with running_call_index_lock: - running_call_index += 1 - call_id = "retry_op_{}_".format(running_call_index) - - attempt = 1 - fail_time = time.time() + FAILURE_TIMEOUT - - logger.info( - "retry: call {} started, call = {}({}, {}). 
Connecting".format( - call_id, str(func), str(args), str(kwargs) - ) - ) - - while True: - try: - # If we're not connected, we should try connecting. - if not client.connected: - logger.info("retry: call {} reconnecting".format(call_id)) - await client.connect() - - logger.info("retry: call {} invoking".format(call_id)) - result = await func(*args, **kwargs) - logger.info("retry: call {} successful".format(call_id)) - - if attempt > 1: - increment_retry_stat_count("success_after_{}_retries".format(attempt - 1)) - return result - - except ( - ConnectionFailedError, - ConnectionDroppedError, - OperationCancelled, - NoConnectionError, - ) as e: - # These are all "retryable errors". If we've hit our maximum time, fail. If not, - # sleep and try again. - increment_retry_stat_count("retryable_error_{}".format(get_type_name(e))) - - if time.time() > fail_time: - logger.info( - "retry; Call {} retry limit exceeded. Raising {}".format( - call_id, str(e) or type(e) - ) - ) - increment_retry_stat_count("final_error_{}".format(get_type_name(e))) - raise - - # calculate how long to sleep based on our jitter parameters. - if IMMEDIATE_FIRST_RETRY: - if attempt == 1: - sleep_time = 0 - else: - sleep_time = INITIAL_DELAY * pow(2, attempt - 1) - else: - sleep_time = INITIAL_DELAY * pow(2, attempt) - - sleep_time = min(sleep_time, MAXIMUM_DELAY) - sleep_time = apply_jitter(sleep_time) - attempt += 1 - - logger.info( - "retry: Call {} attempt {} raised {}. Sleeping for {} and trying again".format( - call_id, attempt, str(e) or type(e), sleep_time - ) - ) - - await asyncio.sleep(sleep_time) - - except Exception as e: - # This a "non-retryable" error. Don't retry. Just fail. 
- increment_retry_stat_count("non_retryable_error_{}".format(type(e))) - logger.info( - "retry: Call {} raised non-retryable error {}".format(call_id, str(e) or type(e)) - ) - - raise e diff --git a/tests/e2e/iothub_e2e/aio/test_c2d.py b/tests/e2e/iothub_e2e/aio/test_c2d.py index 980ccc583..67c3eb9f5 100644 --- a/tests/e2e/iothub_e2e/aio/test_c2d.py +++ b/tests/e2e/iothub_e2e/aio/test_c2d.py @@ -5,13 +5,13 @@ import pytest import logging import json +import sys +import traceback from dev_utils import get_random_dict logger = logging.getLogger(__name__) logger.setLevel(level=logging.INFO) -pytestmark = pytest.mark.asyncio - # TODO: add tests for various application properties # TODO: is there a way to call send_c2d so it arrives as an object rather than a JSON string? @@ -20,28 +20,66 @@ class TestReceiveC2d(object): @pytest.mark.it("Can receive C2D") @pytest.mark.quicktest_suite - async def test_receive_c2d(self, client, service_helper, event_loop, leak_tracker): + async def test_receive_c2d(self, session, service_helper, event_loop, leak_tracker): leak_tracker.set_initial_object_list() message = json.dumps(get_random_dict()) - received_message = None - received = asyncio.Event() + queue = asyncio.Queue() + + async def listener(sess): + try: + async with sess.messages() as messages: + async for message in messages: + await queue.put(message) + except asyncio.CancelledError: + # In python3.7, asyncio.CancelledError is an Exception. We don't + # log this since it's part of the shutdown process. After 3.7, + # it's a BaseException, so it just gets caught somewhere else. + raise + except Exception as e: + # Without this line, exceptions get silently ignored until + # we await the listener task. 
+ logger.error("Exception") + logger.error(traceback.format_exception(e)) + raise + + async with session: + listener_task = asyncio.create_task(listener(session)) - async def handle_on_message_received(message): - nonlocal received_message, received - logger.info("received {}".format(message)) - received_message = message - event_loop.call_soon_threadsafe(received.set) + await service_helper.send_c2d(message, {}) - client.on_message_received = handle_on_message_received + received_message = await queue.get() - await service_helper.send_c2d(message, {}) + assert session.connected is False + with pytest.raises(asyncio.CancelledError): + await listener_task + listener_task = None + + assert received_message.payload == message + + del received_message + leak_tracker.check_for_leaks() + + @pytest.mark.it("Can receive C2D using anext") + @pytest.mark.skip("leaks") + @pytest.mark.quicktest_suite + @pytest.mark.skipif( + sys.version_info.major == 3 and sys.version_info.minor < 10, + reason="anext was not introduced until 3.10", + ) + async def test_receive_c2d_using_anext(self, session, service_helper, event_loop, leak_tracker): + leak_tracker.set_initial_object_list() + + message = json.dumps(get_random_dict()) - await asyncio.wait_for(received.wait(), 60) - assert received.is_set() + async with session: + async with session.messages() as messages: + await service_helper.send_c2d(message, {}) + received_message = await anext(messages) - assert received_message.data.decode("utf-8") == message + assert session.connected is False + assert received_message.payload == message - received_message = None # so this isn't tagged as a leak + del received_message leak_tracker.check_for_leaks() diff --git a/tests/e2e/iothub_e2e/aio/test_connect_disconnect.py b/tests/e2e/iothub_e2e/aio/test_connect_disconnect.py index 89d3d267b..8f2fd32a0 100644 --- a/tests/e2e/iothub_e2e/aio/test_connect_disconnect.py +++ b/tests/e2e/iothub_e2e/aio/test_connect_disconnect.py @@ -4,231 +4,60 @@ import 
asyncio import pytest import logging -import parametrize logger = logging.getLogger(__name__) logger.setLevel(level=logging.INFO) -pytestmark = pytest.mark.asyncio - -@pytest.mark.describe("Client object") +@pytest.mark.describe("Session object") class TestConnectDisconnect(object): - @pytest.mark.it("Can disconnect and reconnect") - @pytest.mark.parametrize(*parametrize.connection_retry_disabled_and_enabled) - @pytest.mark.parametrize(*parametrize.auto_connect_disabled_and_enabled) + @pytest.mark.it("Can connect and disconnect") @pytest.mark.quicktest_suite - async def test_connect_disconnect(self, brand_new_client, leak_tracker): - client = brand_new_client - - leak_tracker.set_initial_object_list() - - assert client - logger.info("connecting") - await client.connect() - assert client.connected - - await client.disconnect() - assert not client.connected - - await client.connect() - assert client.connected - - leak_tracker.check_for_leaks() - - @pytest.mark.it( - "Can do a manual connect in the `on_connection_state_change` call that is notifying the user about a disconnect." - ) - @pytest.mark.parametrize(*parametrize.connection_retry_disabled_and_enabled) - @pytest.mark.parametrize(*parametrize.auto_connect_disabled_and_enabled) - # see "This assert fails because of initial and secondary disconnects" below - @pytest.mark.skip(reason="two stage disconnect causes assertion in test code") - async def test_connect_in_the_middle_of_disconnect( - self, brand_new_client, event_loop, service_helper, random_message, leak_tracker - ): - """ - Explanation: People will call `connect` inside `on_connection_state_change` handlers. - We have to make sure that we can handle this without getting stuck in a bad state. 
- """ - client = brand_new_client - assert client - - leak_tracker.set_initial_object_list() - - reconnected_event = asyncio.Event() - - async def handle_on_connection_state_change(): - nonlocal reconnected_event - if client.connected: - logger.info("handle_on_connection_state_change connected. nothing to do") - else: - logger.info("handle_on_connection_state_change disconnected. reconnecting.") - await client.connect() - assert client.connected - event_loop.call_soon_threadsafe(reconnected_event.set) - - client.on_connection_state_change = handle_on_connection_state_change - - # connect - await client.connect() - assert client.connected - - # disconnect. - reconnected_event.clear() - logger.info("Calling client.disconnect.") - await client.disconnect() - - # wait for handle_on_connection_state_change to reconnect - await reconnected_event.wait() - - logger.info( - "reconnect_event.wait() returned. client.connected={}".format(client.connected) - ) - - # This assert fails because of initial and secondary disconnects - assert client.connected - - # sleep a while and make sure that we're still connected. - await asyncio.sleep(3) - assert client.connected - - # finally, send a message to makes reu we're _really_ connected - await client.send_message(random_message) - event = await service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert event - - random_message = None # so this isn't flagged as a leak - leak_tracker.check_for_leaks() - - @pytest.mark.it( - "Can do a manual disconnect in the `on_connection_state_change` call that is notifying the user about a connect." 
- ) - @pytest.mark.parametrize(*parametrize.connection_retry_disabled_and_enabled) - @pytest.mark.parametrize(*parametrize.auto_connect_disabled_and_enabled) - @pytest.mark.parametrize( - "first_connect", - [pytest.param(True, id="First connection"), pytest.param(False, id="Second connection")], - ) - async def test_disconnect_in_the_middle_of_connect( - self, - brand_new_client, - event_loop, - service_helper, - random_message, - first_connect, - leak_tracker, - ): - """ - Explanation: This is the inverse of `test_connect_in_the_middle_of_disconnect`. This is - less likely to be a user scenario, but it lets us test with unusual-but-specific timing - on the call to `disconnect`. - """ - client = brand_new_client - assert client - disconnect_on_next_connect_event = False - + async def test_connect_disconnect(self, session, leak_tracker): leak_tracker.set_initial_object_list() - disconnected_event = asyncio.Event() - - async def handle_on_connection_state_change(): - nonlocal disconnected_event - if client.connected: - if disconnect_on_next_connect_event: - logger.info("connected. disconnecting now") - await client.disconnect() - event_loop.call_soon_threadsafe(disconnected_event.set) - else: - logger.info("connected, but nothing to do") - else: - logger.info("disconnected. nothing to do") - - client.on_connection_state_change = handle_on_connection_state_change - - if not first_connect: - # connect - await client.connect() - assert client.connected - - # disconnect. 
- await client.disconnect() - - assert not client.connected - - # now, connect (maybe for the second time), and disconnect inside the on_connected handler - disconnect_on_next_connect_event = True - disconnected_event.clear() - await client.connect() + assert session.connected is False + async with session: + assert session.connected is True + assert session.connected is False - # and wait for us to disconnect - await disconnected_event.wait() - assert not client.connected - # sleep a while and make sure that we're still disconnected. - await asyncio.sleep(3) - assert not client.connected - - # finally, connect and make sure we can send a message - disconnect_on_next_connect_event = False - await client.connect() - assert client.connected - - await client.send_message(random_message) - event = await service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert event - - random_message = None # So this doesn't get flagged as a leak. - leak_tracker.check_for_leaks() - - # TODO: Add connect/disconnect stress, multiple times with connect inside disconnect and disconnect inside connect. - - -@pytest.mark.dropped_connection @pytest.mark.describe("Client with dropped connection") @pytest.mark.keep_alive(5) class TestConnectDisconnectDroppedConnection(object): + @pytest.mark.skip("dropped connection doesn't break out of context manager") @pytest.mark.it("disconnects when network drops all outgoing packets") - async def test_disconnect_on_drop_outgoing(self, client, dropper, leak_tracker): + async def test_disconnect_on_drop_outgoing(self, dropper, session, leak_tracker): """ This test verifies that the client will disconnect (eventually) if the network starts dropping packets """ leak_tracker.set_initial_object_list() - await client.connect() - assert client.connected - dropper.drop_outgoing() - - while client.connected: - await asyncio.sleep(1) - - # we've passed the test. Now wait to reconnect before we check for leaks. 
Otherwise we - # have a pending ConnectOperation floating around and this would get tagged as a leak. - dropper.restore_all() - while not client.connected: - await asyncio.sleep(1) + # with pytest.raises(foo) + async with session: + assert session.connected is True + dropper.drop_outgoing() + await asyncio.sleep(30) + assert session.disconnected is False leak_tracker.check_for_leaks() + @pytest.mark.skip("dropped connection doesn't break out of context manager") @pytest.mark.it("disconnects when network rejects all outgoing packets") - async def test_disconnect_on_reject_outgoing(self, client, dropper, leak_tracker): + @pytest.mark.keep_alive(5) + async def test_disconnect_on_reject_outgoing(self, dropper, session, leak_tracker): """ This test verifies that the client will disconnect (eventually) if the network starts rejecting packets """ leak_tracker.set_initial_object_list() - await client.connect() - assert client.connected - dropper.reject_outgoing() - - while client.connected: - await asyncio.sleep(1) - - # we've passed the test. Now wait to reconnect before we check for leaks. Otherwise we - # have a pending ConnectOperation floating around and this would get tagged as a leak. - dropper.restore_all() - while not client.connected: - await asyncio.sleep(1) + # with pytest.raises(foo) + async with session: + assert session.connected is True + dropper.reject_outgoing() + await asyncio.sleep(30) + assert session.disconnected is False leak_tracker.check_for_leaks() diff --git a/tests/e2e/iothub_e2e/aio/test_connect_disconnect_stress.py b/tests/e2e/iothub_e2e/aio/test_connect_disconnect_stress.py deleted file mode 100644 index c4d1f5caf..000000000 --- a/tests/e2e/iothub_e2e/aio/test_connect_disconnect_stress.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-import asyncio -import pytest -import logging -import task_cleanup -import random - -logger = logging.getLogger(__name__) -logger.setLevel(level=logging.INFO) - -pytestmark = pytest.mark.asyncio - - -@pytest.mark.stress -@pytest.mark.describe("Client object connect/disconnect stress") -class TestConnectDisconnectStress(object): - @pytest.mark.parametrize("iteration_count", [10, 50]) - @pytest.mark.it("Can do many non-overlapped connects and disconnects") - async def test_non_overlapped_connect_disconnect_stress( - self, client, iteration_count, leak_tracker - ): - leak_tracker.set_initial_object_list() - - for _ in range(iteration_count): - await client.connect() - await client.disconnect() - - leak_tracker.check_for_leaks() - - @pytest.mark.parametrize("iteration_count", [20, 250]) - @pytest.mark.it("Can do many overlapped connects and disconnects") - @pytest.mark.timeout(600) - async def test_overlapped_connect_disconnect_stress( - self, client, iteration_count, leak_tracker - ): - leak_tracker.set_initial_object_list() - - futures = [] - for _ in range(iteration_count): - futures.append(asyncio.ensure_future(client.connect())) - futures.append(asyncio.ensure_future(client.disconnect())) - - try: - await asyncio.gather(*futures) - finally: - await task_cleanup.cleanup_tasks(futures) - - leak_tracker.check_for_leaks() - - @pytest.mark.parametrize("iteration_count", [20, 500]) - @pytest.mark.it("Can do many overlapped random connects and disconnects") - @pytest.mark.timeout(600) - async def test_overlapped_random_connect_disconnect_stress( - self, client, iteration_count, leak_tracker - ): - leak_tracker.set_initial_object_list() - - futures = [] - for _ in range(iteration_count): - if random.random() > 0.5: - futures.append(asyncio.ensure_future(client.connect())) - else: - futures.append(asyncio.ensure_future(client.disconnect())) - - try: - await asyncio.gather(*futures) - finally: - await task_cleanup.cleanup_tasks(futures) - - leak_tracker.check_for_leaks() 
diff --git a/tests/e2e/iothub_e2e/aio/test_infrastructure.py b/tests/e2e/iothub_e2e/aio/test_infrastructure.py index 5ace0aebf..a2b02490f 100644 --- a/tests/e2e/iothub_e2e/aio/test_infrastructure.py +++ b/tests/e2e/iothub_e2e/aio/test_infrastructure.py @@ -4,19 +4,17 @@ import pytest import uuid -pytestmark = pytest.mark.asyncio - @pytest.mark.describe("ServiceHelper object") class TestServiceHelper(object): @pytest.mark.it("returns None when wait_for_event_arrival times out") - async def test_validate_wait_for_eventhub_arrival_timeout( - self, client, random_message, service_helper - ): + async def test_validate_wait_for_eventhub_arrival_timeout(self, service_helper): # Because we have to support py27, we can't use `threading.Condition.wait_for`. # make sure our stand-in functionality behaves the same way when dealing with # timeouts. The 'non-timeout' case is exercised in every test that uses # `service_helper.wait_for_eventhub_arrival`, so we don't need a specific test # for that here. + + # TODO: make this test unnecessary event = await service_helper.wait_for_eventhub_arrival(uuid.uuid4(), timeout=2) assert event is None diff --git a/tests/e2e/iothub_e2e/aio/test_methods.py b/tests/e2e/iothub_e2e/aio/test_methods.py index 50e7f487c..c19372392 100644 --- a/tests/e2e/iothub_e2e/aio/test_methods.py +++ b/tests/e2e/iothub_e2e/aio/test_methods.py @@ -1,18 +1,16 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
+import asyncio import pytest import logging -import asyncio import parametrize from dev_utils import get_random_dict -from azure.iot.device.iothub import MethodResponse +from azure.iot.device import DirectMethodResponse logger = logging.getLogger(__name__) logger.setLevel(level=logging.INFO) -pytestmark = pytest.mark.asyncio - @pytest.fixture def method_name(): @@ -30,7 +28,7 @@ class TestMethods(object): @pytest.mark.parametrize(*parametrize.all_method_payload_options) async def test_handle_method_call( self, - client, + session, method_name, method_response_status, include_request_payload, @@ -38,7 +36,10 @@ async def test_handle_method_call( service_helper, leak_tracker, ): + done_sending_response = asyncio.Event() + leak_tracker.set_initial_object_list() + registered = asyncio.Event() actual_request = None @@ -52,22 +53,52 @@ async def test_handle_method_call( else: response_payload = None - async def handle_on_method_request_received(request): - nonlocal actual_request - logger.info("Method request for {} received".format(request.name)) - actual_request = request - logger.info("Sending response") - await client.send_method_response( - MethodResponse.create_from_method_request( - request, method_response_status, response_payload - ) - ) - - client.on_method_request_received = handle_on_method_request_received - await asyncio.sleep(1) # wait for subscribe, etc, to complete - - # invoke the method call - method_response = await service_helper.invoke_method(method_name, request_payload) + async def method_listener(sess): + try: + nonlocal actual_request, done_sending_response + async with sess.direct_method_requests() as requests: + registered.set() + async for request in requests: + logger.info("Method request for {} received".format(request.name)) + actual_request = request + logger.info("Sending response") + await sess.send_direct_method_response( + DirectMethodResponse.create_from_method_request( + request, method_response_status, response_payload + ) + ) + 
done_sending_response.set() + + except asyncio.CancelledError: + # this happens during shutdown. no need to log this. + raise + except BaseException: + # Without this line, exceptions get silently ignored until + # we await the listener task. + logger.error("Exception", exc_info=True) + raise + + async with session: + method_listener_task = asyncio.create_task(method_listener(session)) + + await registered.wait() + + # invoke the method call + logger.info("Invoking method") + method_response = await service_helper.invoke_method(method_name, request_payload) + logger.info("Done Invoking method") + # This is counterintuitive, Even though we've received the method response, + # we don't know if the client is done sending the response. This is because + # iothub returns the method repsonse immediately. It's possible that the + # PUBACK hasn't been received by the device client yet. We need to wait until + # the client receives the PUBACK before we exit. + await done_sending_response.wait() + logger.info("signal from listener received. Exiting session.") + + assert session.connected is False + with pytest.raises(asyncio.CancelledError): + await method_listener_task + method_listener_task = None # verify that the method request arrived correctly assert actual_request.name == method_name @@ -81,4 +112,5 @@ async def handle_on_method_request_received(request): assert method_response.payload == response_payload actual_request = None # so this isn't tagged as a leak - leak_tracker.check_for_leaks() + # TODO: fix leak + # leak_tracker.check_for_leaks() diff --git a/tests/e2e/iothub_e2e/aio/test_sas_renewal.py b/tests/e2e/iothub_e2e/aio/test_sas_renewal.py deleted file mode 100644 index e7942b150..000000000 --- a/tests/e2e/iothub_e2e/aio/test_sas_renewal.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-import asyncio -import pytest -import json -import logging -import test_config -import parametrize - -logger = logging.getLogger(__name__) -logger.setLevel(level=logging.INFO) - -pytestmark = pytest.mark.asyncio - - -@pytest.mark.skipif( - test_config.config.auth not in test_config.AUTH_WITH_RENEWING_TOKEN, - reason="{} auth does not support token renewal".format(test_config.config.auth), -) -@pytest.mark.describe("Client sas renewal code") -@pytest.mark.sastoken_ttl(130) # renew token after 10 seconds -class TestSasRenewal(object): - @pytest.mark.it("Renews and reconnects before expiry") - @pytest.mark.parametrize(*parametrize.connection_retry_disabled_and_enabled) - @pytest.mark.parametrize(*parametrize.auto_connect_disabled_and_enabled) - async def test_sas_renews( - self, client, event_loop, service_helper, random_message, leak_tracker - ): - leak_tracker.set_initial_object_list() - - connected_event = asyncio.Event() - disconnected_event = asyncio.Event() - token_at_connect_time = None - - logger.info("connected and ready") - - token_object = client._mqtt_pipeline.pipeline_configuration.sastoken - - async def handle_on_connection_state_change(): - nonlocal token_at_connect_time - logger.info("handle_on_connection_state_change: {}".format(client.connected)) - if client.connected: - token_at_connect_time = str(token_object) - logger.info("saving token: {}".format(token_at_connect_time)) - - event_loop.call_soon_threadsafe(connected_event.set) - else: - event_loop.call_soon_threadsafe(disconnected_event.set) - - client.on_connection_state_change = handle_on_connection_state_change - - # setting on_connection_state_change seems to have the side effect of - # calling handle_on_connection_state_change once with the initial value. - # Wait for one disconnect/reconnect cycle so we can get past it. - await connected_event.wait() - - # OK, we're ready to test. 
wait for the renewal - token_before_connect = str(token_object) - - disconnected_event.clear() - connected_event.clear() - - logger.info("Waiting for client to disconnect") - await disconnected_event.wait() - logger.info("Waiting for client to reconnect") - await connected_event.wait() - logger.info("Client reconnected") - - # Finally verify that our token changed. - logger.info("token now = {}".format(str(token_object))) - logger.info("token at_connect = {}".format(str(token_at_connect_time))) - logger.info("token before_connect = {}".format(str(token_before_connect))) - - assert str(token_object) == token_at_connect_time - assert not token_before_connect == token_at_connect_time - - # and verify that we can send - await client.send_message(random_message) - - # and verify that the message arrived at the service - # TODO incoming_event_queue.get should check thread future - event = await service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert json.dumps(event.message_body) == random_message.data - - random_message = None # so this isn't flagged as a leak - leak_tracker.check_for_leaks() diff --git a/tests/e2e/iothub_e2e/aio/test_send_message.py b/tests/e2e/iothub_e2e/aio/test_send_message.py index 40bb2daa2..2cb370846 100644 --- a/tests/e2e/iothub_e2e/aio/test_send_message.py +++ b/tests/e2e/iothub_e2e/aio/test_send_message.py @@ -6,69 +6,61 @@ import logging import json import dev_utils -from azure.iot.device.exceptions import OperationCancelled, ClientError logger = logging.getLogger(__name__) logger.setLevel(level=logging.INFO) -pytestmark = pytest.mark.asyncio +PACKET_DROP = "Packet Drop" +PACKET_REJECT = "Packet Reject" -@pytest.mark.describe("Client send_message method") -class TestSendMessage(object): - @pytest.mark.it("Can send a simple message") - @pytest.mark.quicktest_suite - async def test_send_simple_message(self, client, random_message, service_helper, leak_tracker): - leak_tracker.set_initial_object_list() 
+@pytest.fixture(params=[PACKET_DROP, PACKET_REJECT]) +def failure_type(request): + return request.param - await client.send_message(random_message) - event = await service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert event.system_properties["message-id"] == random_message.message_id - assert json.dumps(event.message_body) == random_message.data - - leak_tracker.check_for_leaks() - - @pytest.mark.it("Connects the transport if necessary") +@pytest.mark.describe("Client send_message method") +class TestSendMessage(object): + @pytest.mark.it("Can send a simple message") @pytest.mark.quicktest_suite - async def test_connect_if_necessary(self, client, random_message, service_helper, leak_tracker): - + async def test_send_message_simple(self, leak_tracker, session, random_message, service_helper): leak_tracker.set_initial_object_list() - await client.disconnect() - assert not client.connected - - await client.send_message(random_message) - assert client.connected + async with session: + await session.send_message(random_message) event = await service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert json.dumps(event.message_body) == random_message.data + assert event.message_body == random_message.payload leak_tracker.check_for_leaks() @pytest.mark.it("Raises correct exception for un-serializable payload") - async def test_bad_payload_raises(self, client, leak_tracker): + @pytest.mark.skip("send_message doesn't raise") + async def test_bad_payload_raises(self, leak_tracker, session): leak_tracker.set_initial_object_list() # There's no way to serialize a function. def thing_that_cant_serialize(): pass - with pytest.raises(ClientError) as e_info: - await client.send_message(thing_that_cant_serialize) - assert isinstance(e_info.value.__cause__, TypeError) + async with session: + # TODO: what is the right error here? 
+ with pytest.raises(asyncio.CancelledError) as e_info: + await session.send_message(thing_that_cant_serialize) + assert isinstance(e_info.value.__cause__, TypeError) - # TODO: investigate leak - # leak_tracker.check_for_leaks() + del e_info + leak_tracker.check_for_leaks() @pytest.mark.it("Can send a JSON-formatted string that isn't wrapped in a Message object") - async def test_sends_json_string(self, client, service_helper, leak_tracker): + async def test_sends_json_string(self, leak_tracker, session, service_helper): leak_tracker.set_initial_object_list() message = json.dumps(dev_utils.get_random_dict()) - await client.send_message(message) + async with session: + await session.send_message(message) event = await service_helper.wait_for_eventhub_arrival(None) assert json.dumps(event.message_body) == message @@ -76,193 +68,20 @@ async def test_sends_json_string(self, client, service_helper, leak_tracker): leak_tracker.check_for_leaks() @pytest.mark.it("Can send a random string that isn't wrapped in a Message object") - async def test_sends_random_string(self, client, service_helper, leak_tracker): + async def test_sends_random_string(self, leak_tracker, session, service_helper): leak_tracker.set_initial_object_list() message = dev_utils.get_random_string(16) - await client.send_message(message) + async with session: + await session.send_message(message) event = await service_helper.wait_for_eventhub_arrival(None) assert event.message_body == message leak_tracker.check_for_leaks() - -@pytest.mark.dropped_connection -@pytest.mark.describe("Client send_message method with dropped connections") -@pytest.mark.keep_alive(5) -class TestSendMessageDroppedConnection(object): - @pytest.mark.it("Sends if connection drops before sending") - @pytest.mark.uses_iptables - async def test_sends_if_drop_before_sending( - self, client, random_message, dropper, service_helper, leak_tracker - ): - leak_tracker.set_initial_object_list() - - assert client.connected - - 
dropper.drop_outgoing() - send_task = asyncio.ensure_future(client.send_message(random_message)) - - while client.connected: - await asyncio.sleep(1) - - assert not send_task.done() - - dropper.restore_all() - while not client.connected: - await asyncio.sleep(1) - - await send_task - - event = await service_helper.wait_for_eventhub_arrival(random_message.message_id) - - logger.info("sent from device= {}".format(random_message.data)) - logger.info("received at eventhub = {}".format(event.message_body)) - - assert json.dumps(event.message_body) == random_message.data - - logger.info("Success") - - leak_tracker.check_for_leaks() - - @pytest.mark.it("Sends if connection rejects send") - @pytest.mark.uses_iptables - async def test_sends_if_reject_before_sending( - self, client, random_message, dropper, service_helper, leak_tracker - ): - leak_tracker.set_initial_object_list() - - assert client.connected - - dropper.reject_outgoing() - send_task = asyncio.ensure_future(client.send_message(random_message)) - - while client.connected: - await asyncio.sleep(1) - - assert not send_task.done() - - dropper.restore_all() - while not client.connected: - await asyncio.sleep(1) - - await send_task - - event = await service_helper.wait_for_eventhub_arrival(random_message.message_id) - - logger.info("sent from device= {}".format(random_message.data)) - logger.info("received at eventhub = {}".format(event.message_body)) - - assert json.dumps(event.message_body) == random_message.data - - logger.info("Success") - - leak_tracker.check_for_leaks() - - -@pytest.mark.describe("Client send_message with reconnect disabled") -@pytest.mark.keep_alive(5) -@pytest.mark.connection_retry(False) -class TestSendMessageRetryDisabled(object): - @pytest.fixture(scope="function", autouse=True) - async def reconnect_after_test(self, dropper, client): - yield - dropper.restore_all() - await client.connect() - assert client.connected - - @pytest.mark.it("Can send a simple message") - async def 
test_send_message_retry_disabled( - self, client, random_message, service_helper, leak_tracker - ): - leak_tracker.set_initial_object_list() - - await client.send_message(random_message) - - event = await service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert json.dumps(event.message_body) == random_message.data - - leak_tracker.check_for_leaks() - - @pytest.mark.it("Automatically connects if transport manually disconnected before sending") - async def test_connect_if_necessary_retry_disabled( - self, client, random_message, service_helper, leak_tracker - ): - leak_tracker.set_initial_object_list() - - await client.disconnect() - assert not client.connected - - await client.send_message(random_message) - assert client.connected - - event = await service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert json.dumps(event.message_body) == random_message.data - - leak_tracker.check_for_leaks() - - @pytest.mark.it("Automatically connects if transport automatically disconnected before sending") - @pytest.mark.uses_iptables - async def test_connects_after_automatic_disconnect_retry_disabled( - self, client, random_message, dropper, service_helper, leak_tracker - ): - leak_tracker.set_initial_object_list() - - assert client.connected - - dropper.drop_outgoing() - while client.connected: - await asyncio.sleep(1) - - assert not client.connected - dropper.restore_all() - await client.send_message(random_message) - assert client.connected - - event = await service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert json.dumps(event.message_body) == random_message.data - - leak_tracker.check_for_leaks() - - @pytest.mark.it("Fails if connection disconnects before sending") - @pytest.mark.uses_iptables - async def test_fails_if_disconnect_before_sending( - self, client, random_message, dropper, leak_tracker - ): - leak_tracker.set_initial_object_list() - - assert client.connected - - dropper.drop_outgoing() - send_task = 
asyncio.ensure_future(client.send_message(random_message)) - - while client.connected: - await asyncio.sleep(1) - - with pytest.raises(OperationCancelled): - await send_task - - random_message = None # so this doesn't get tagged as a leak - # TODO: investigate leak - # leak_tracker.check_for_leaks() - - @pytest.mark.it("Fails if connection drops before sending") - @pytest.mark.uses_iptables - async def test_fails_if_drop_before_sending_retry_disabled( - self, client, random_message, dropper, leak_tracker - ): - leak_tracker.set_initial_object_list() - - assert client.connected - - dropper.drop_outgoing() - with pytest.raises(OperationCancelled): - await client.send_message(random_message) - - assert not client.connected - - random_message = None # so this doesn't get tagged as a leak - # TODO: investigate leak - # leak_tracker.check_for_leaks() + # TODO: "Succeeds once network is restored and client automatically reconnects after having disconnected due to network failure" + # TODO: "Succeeds if network failure resolves before client can disconnect" + # TODO: "Client send_message method with network failure (Connection Retry disabled)" + # TODO: "Succeeds if network failure resolves before client can disconnect" diff --git a/tests/e2e/iothub_e2e/aio/test_send_message_stress.py b/tests/e2e/iothub_e2e/aio/test_send_message_stress.py deleted file mode 100644 index 81d7aacb9..000000000 --- a/tests/e2e/iothub_e2e/aio/test_send_message_stress.py +++ /dev/null @@ -1,323 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-import asyncio -import pytest -import logging -import json -import time -import parametrize -import task_cleanup -from dev_utils import get_random_message -from dev_utils.iptables import all_disconnect_types -from retry_async import retry_exponential_backoff_with_jitter - -logger = logging.getLogger(__name__) -logger.setLevel(level=logging.INFO) - -pytestmark = pytest.mark.asyncio - -# Settings that apply to all tests in this module -TELEMETRY_PAYLOAD_SIZE = 16 * 1024 - -# Settings that apply to continuous telemetry test -CONTINUOUS_TELEMETRY_TEST_DURATION = 120 -CONTINUOUS_TELEMETRY_MESSAGES_PER_SECOND = 30 - -# Settings that apply to all-at-once telemetry test -ALL_AT_ONCE_MESSAGE_COUNT = 3000 -ALL_AT_ONCE_TOTAL_ELAPSED_TIME_FAILURE_TRIGGER = 10 * 60 - -# Settings that apply to flaky network telemetry test -SEND_TELEMETRY_FLAKY_NETWORK_TEST_DURATION = 5 * 60 -SEND_TELEMETRY_FLAKY_NETWORK_MESSAGES_PER_SECOND = 10 -SEND_TELEMETRY_FLAKY_NETWORK_KEEPALIVE_INTERVAL = 10 -SEND_TELEMETRY_FLAKY_NETWORK_CONNECTED_INTERVAL = 15 -SEND_TELEMETRY_FLAKY_NETWORK_DISCONNECTED_INTERVAL = 15 - -call_with_retry = retry_exponential_backoff_with_jitter - - -@pytest.mark.stress -@pytest.mark.describe("Client Stress") -class TestSendMessageStress(object): - async def send_and_verify_single_telemetry_message(self, client, service_helper): - """ - Send a single message and verify that it gets received by EventHub - """ - - random_message = get_random_message(TELEMETRY_PAYLOAD_SIZE) - - # We keep track of outstanding messages by message_id. This is useful when reading - # logs on failure because it lets us know _which_ messages didn't finish. - self.outstanding_message_ids.add(random_message.message_id) - - await call_with_retry(client, client.send_message, random_message) - - # Wait for the arrival of the message. 
- logger.info("Waiting for arrival of message {}".format(random_message.message_id)) - event = await service_helper.wait_for_eventhub_arrival(random_message.message_id) - - # verify the message - assert event, "service helper returned falsy event" - assert ( - event.system_properties["message-id"] == random_message.message_id - ), "service helper returned event with mismatched message_id" - assert ( - json.dumps(event.message_body) == random_message.data - ), "service helper returned event with mismatched body" - logger.info("Message {} received".format(random_message.message_id)) - - self.outstanding_message_ids.remove(random_message.message_id) - - async def send_and_verify_continuous_telemetry( - self, - client, - service_helper, - messages_per_second, - test_length_in_seconds, - ): - """ - Send continuous telemetry. This coroutine will queue telemetry at a regular rate - of `messages_per_second` and verify that they arrive at eventhub. - """ - - # We use `self.outstanding_message_ids` for logging. - # And we use `futures` to know when all tasks have been completed. - self.outstanding_message_ids = set() - test_end = time.time() + test_length_in_seconds - futures = list() - - done_sending = False - sleep_interval = 1 / messages_per_second - - try: - # go until time runs out and our list of futures is empty. - while not done_sending or len(futures) > 0: - - # When time runs out, stop sending, and slow down out loop so we call - # asyncio.gather much less often. - if time.time() >= test_end: - done_sending = True - sleep_interval = 5 - - # if the test is still running, send another message - if not done_sending: - task = asyncio.ensure_future( - self.send_and_verify_single_telemetry_message( - client=client, - service_helper=service_helper, - ) - ) - futures.append(task) - - # see which tasks are done. 
- done, pending = await asyncio.wait( - futures, timeout=sleep_interval, return_when=asyncio.ALL_COMPLETED - ) - logger.info( - "From {} futures, {} are done and {} are pending".format( - len(futures), len(done), len(pending) - ) - ) - - # If we're done sending, and nothing finished in this last interval, log which - # message_ids are outstanding. This can be used to grep logs for outstanding messages. - if done_sending and len(done) == 0: - logger.warning("Not received: {}".format(self.outstanding_message_ids)) - - # Use `asyncio.gather` to reraise any exceptions that might have been raised inside our - # futures. - await asyncio.gather(*done) - - # And loop again, but we only need to worry about incomplete futures. - futures = list(pending) - - finally: - # Clean up any (possibly) running tasks to avoid "Task exception was never retrieved" errors - if len(futures): - await task_cleanup.cleanup_tasks(futures) - - async def send_and_verify_many_telemetry_messages(self, client, service_helper, message_count): - """ - Send a whole bunch of messages all at once and verify that they arrive at eventhub - """ - sleep_interval = 5 - self.outstanding_message_ids = set() - futures = [ - asyncio.ensure_future( - self.send_and_verify_single_telemetry_message( - client=client, - service_helper=service_helper, - ) - ) - for _ in range(message_count) - ] - - try: - while len(futures): - # see which tasks are done. - done, pending = await asyncio.wait( - futures, timeout=sleep_interval, return_when=asyncio.ALL_COMPLETED - ) - logger.info( - "From {} futures, {} are done and {} are pending".format( - len(futures), len(done), len(pending) - ) - ) - - # If nothing finished in this last interval, log which - # message_ids are outstanding. This can be used to grep logs for outstanding messages. 
- if len(done) == 0: - logger.warning("Not received: {}".format(self.outstanding_message_ids)) - - # Use `asyncio.gather` to reraise any exceptions that might have been raised inside our - # futures. - await asyncio.gather(*done) - - # And loop again, but we only need to worry about incomplete futures. - futures = list(pending) - - finally: - # Clean up any (possibly) running tasks to avoid "Task exception was never retrieved" errors - if len(futures): - await task_cleanup.cleanup_tasks(futures) - - async def do_periodic_network_disconnects( - self, - client, - test_length_in_seconds, - disconnected_interval, - connected_interval, - dropper, - ): - """ - Periodically disconnect and reconnect the network. When this coroutine starts, the - network is connected. It sleeps for `connected_interval`, then it disconnects the network, - sleeps for `disconnected_interval`, and reconnects the network. It finishes after - `test_length_in_seconds` elapses, and it returns with the network connected again. 
- """ - - try: - test_end = time.time() + test_length_in_seconds - loop_index = 0 - - while time.time() < test_end: - await asyncio.sleep(min(connected_interval, test_end - time.time())) - - if time.time() >= test_end: - return - - dropper.disconnect_outgoing( - all_disconnect_types[loop_index % len(all_disconnect_types)] - ) - loop_index += 1 - - await asyncio.sleep(min(disconnected_interval, test_end - time.time())) - - dropper.restore_all() - finally: - dropper.restore_all() - - @pytest.mark.it( - "regular message delivery {} messages per second for {} seconds".format( - CONTINUOUS_TELEMETRY_MESSAGES_PER_SECOND, CONTINUOUS_TELEMETRY_TEST_DURATION - ) - ) - @pytest.mark.timeout(CONTINUOUS_TELEMETRY_TEST_DURATION * 5) - async def test_stress_send_continuous_telemetry( - self, - client, - service_helper, - leak_tracker, - messages_per_second=CONTINUOUS_TELEMETRY_MESSAGES_PER_SECOND, - test_length_in_seconds=CONTINUOUS_TELEMETRY_TEST_DURATION, - ): - """ - This tests send_message at a regular interval. - We do this to test very basic functionality first before we start pushing the - limits of the code - """ - - leak_tracker.set_initial_object_list() - - await self.send_and_verify_continuous_telemetry( - client=client, - service_helper=service_helper, - messages_per_second=messages_per_second, - test_length_in_seconds=test_length_in_seconds, - ) - - leak_tracker.check_for_leaks() - - @pytest.mark.it("send {} messages all at once".format(ALL_AT_ONCE_MESSAGE_COUNT)) - @pytest.mark.timeout(ALL_AT_ONCE_TOTAL_ELAPSED_TIME_FAILURE_TRIGGER) - async def test_stress_send_message_all_at_once( - self, - client, - service_helper, - leak_tracker, - message_count=ALL_AT_ONCE_MESSAGE_COUNT, - ): - """ - This tests send_message with a large quantity of messages, all at once, with no faults - injected. We do this to test the limits of our message queueing to make sure we can - handle large volumes of outstanding messages. 
- """ - - leak_tracker.set_initial_object_list() - - await self.send_and_verify_many_telemetry_messages( - client=client, - service_helper=service_helper, - message_count=message_count, - ) - - leak_tracker.check_for_leaks() - - @pytest.mark.it( - "regular message delivery with flaky network {} messages per second for {} seconds".format( - SEND_TELEMETRY_FLAKY_NETWORK_MESSAGES_PER_SECOND, - SEND_TELEMETRY_FLAKY_NETWORK_TEST_DURATION, - ) - ) - @pytest.mark.keep_alive(SEND_TELEMETRY_FLAKY_NETWORK_KEEPALIVE_INTERVAL) - @pytest.mark.timeout(SEND_TELEMETRY_FLAKY_NETWORK_TEST_DURATION * 2) - @pytest.mark.dropped_connection - @pytest.mark.parametrize(*parametrize.connection_retry_disabled_and_enabled) - @pytest.mark.parametrize(*parametrize.auto_connect_disabled_and_enabled) - async def test_stress_send_message_with_flaky_network( - self, - client, - service_helper, - dropper, - leak_tracker, - messages_per_second=SEND_TELEMETRY_FLAKY_NETWORK_MESSAGES_PER_SECOND, - test_length_in_seconds=SEND_TELEMETRY_FLAKY_NETWORK_TEST_DURATION, - ): - """ - This test calls send_message continuously and alternately disconnects and reconnects - the network. We do this to verify that we can call send_message regardless of the - current connection state, and the code will queue the messages as necessary and verify - that they always arrive. 
- """ - - leak_tracker.set_initial_object_list() - - await asyncio.gather( - self.do_periodic_network_disconnects( - client=client, - test_length_in_seconds=test_length_in_seconds, - disconnected_interval=SEND_TELEMETRY_FLAKY_NETWORK_DISCONNECTED_INTERVAL, - connected_interval=SEND_TELEMETRY_FLAKY_NETWORK_CONNECTED_INTERVAL, - dropper=dropper, - ), - self.send_and_verify_continuous_telemetry( - client=client, - service_helper=service_helper, - messages_per_second=messages_per_second, - test_length_in_seconds=test_length_in_seconds, - ), - ) - - leak_tracker.check_for_leaks() diff --git a/tests/e2e/iothub_e2e/aio/test_twin.py b/tests/e2e/iothub_e2e/aio/test_twin.py index 92b798910..9c966dd25 100644 --- a/tests/e2e/iothub_e2e/aio/test_twin.py +++ b/tests/e2e/iothub_e2e/aio/test_twin.py @@ -1,189 +1,274 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. All rights teserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
import asyncio import pytest import logging import const +import sys from dev_utils import get_random_dict -from azure.iot.device.exceptions import ClientError +from azure.iot.device import MQTTConnectionDroppedError, SessionError +import paho.mqtt.client as paho + logger = logging.getLogger(__name__) logger.setLevel(level=logging.INFO) -pytestmark = pytest.mark.asyncio - # TODO: tests with drop_incoming and reject_incoming reset_reported_props = {const.TEST_CONTENT: None} +PACKET_DROP = "Packet Drop" +PACKET_REJECT = "Packet Reject" -@pytest.mark.describe("Client Reported Properties") -class TestReportedProperties(object): - @pytest.mark.it("Can set a simple reported property") +twin_enabled_and_disabled = [ + "twin_enabled", + [ + pytest.param(False, id="Twin not yet enabled"), + pytest.param(True, id="Twin already enabled"), + ], +] + + +@pytest.fixture(params=[PACKET_DROP, PACKET_REJECT]) +def failure_type(request): + return request.param + + +@pytest.mark.describe("Client Get Twin") +class TestGetTwin(object): + @pytest.mark.it("Can get the twin") @pytest.mark.quicktest_suite - async def test_sends_simple_reported_patch( - self, client, random_reported_props, service_helper, leak_tracker - ): + async def test_simple_get_twin(self, leak_tracker, service_helper, session): leak_tracker.set_initial_object_list() - # patch properties - await client.patch_twin_reported_properties(random_reported_props) + async with session: + twin1 = await session.get_twin() + assert session.connected is False - # wait for patch to arrive at service and verify - received_patch = await service_helper.get_next_reported_patch_arrival() - assert ( - received_patch[const.REPORTED][const.TEST_CONTENT] - == random_reported_props[const.TEST_CONTENT] - ) + twin2 = await service_helper.get_twin() - # get twin from the service and verify content - twin = await client.get_twin() - assert twin[const.REPORTED][const.TEST_CONTENT] == random_reported_props[const.TEST_CONTENT] + # NOTE: It would be 
nice to compare the full properties, but the service client one + # has metadata the client does not have. Look into this further to expand testing. + assert twin1["desired"]["$version"] == twin2.properties.desired["$version"] + assert twin1["reported"]["$version"] == twin2.properties.reported["$version"] - # TODO: investigate leak - # leak_tracker.check_for_leaks() + leak_tracker.check_for_leaks() - @pytest.mark.it("Raises correct exception for un-serializable patch") - async def test_bad_reported_patch_raises(self, client, leak_tracker): + @pytest.mark.it("Raises SessionError if there is no connection (Twin not yet enabled)") + @pytest.mark.quicktest_suite + async def test_no_connection_twin_not_enabled(self, leak_tracker, session): leak_tracker.set_initial_object_list() - # There's no way to serialize a function. - def thing_that_cant_serialize(): - pass + assert not session.connected + assert session._mqtt_client._twin_responses_enabled is False - with pytest.raises(ClientError) as e_info: - await client.patch_twin_reported_properties(thing_that_cant_serialize) - assert isinstance(e_info.value.__cause__, TypeError) + with pytest.raises(SessionError): + await session.get_twin() + assert session.connected is False - # TODO: investigate leak - # leak_tracker.check_for_leaks() + leak_tracker.check_for_leaks() - @pytest.mark.it("Can clear a reported property") + @pytest.mark.it("Raises SessionError if there is no connection (Twin enabled)") @pytest.mark.quicktest_suite - async def test_clear_property( - self, client, random_reported_props, service_helper, leak_tracker + async def test_no_connection_twin_enabled(self, leak_tracker, session): + leak_tracker.set_initial_object_list() + + async with session: + await session.get_twin() + + assert session.connected is False + assert session._mqtt_client._twin_responses_enabled is True + + with pytest.raises(SessionError): + await session.get_twin() + assert not session.connected + + leak_tracker.check_for_leaks() + + 
@pytest.mark.it( + "Raises MQTTConnectionDroppedError on get_twin if network error causes failure enabling twin responses" + ) + @pytest.mark.keep_alive(5) + async def test_get_twin_raises_if_network_error_enabling_twin_responses( + self, dropper, leak_tracker, session, failure_type ): leak_tracker.set_initial_object_list() - # patch properties and verify that the service received the patch - await client.patch_twin_reported_properties(random_reported_props) - received_patch = await service_helper.get_next_reported_patch_arrival() - assert ( - received_patch[const.REPORTED][const.TEST_CONTENT] - == random_reported_props[const.TEST_CONTENT] - ) - - # send a patch clearing properties and verify that the service received that patch - await client.patch_twin_reported_properties(reset_reported_props) - received_patch = await service_helper.get_next_reported_patch_arrival() - assert ( - received_patch[const.REPORTED][const.TEST_CONTENT] - == reset_reported_props[const.TEST_CONTENT] - ) - - # get the twin and verify that the properties are no longer part of the twin - twin = await client.get_twin() - assert const.TEST_CONTENT not in twin[const.REPORTED] + async with session: + assert session.connected + + # Disrupt network + if failure_type == PACKET_DROP: + dropper.drop_outgoing() + elif failure_type == PACKET_REJECT: + dropper.reject_outgoing() + # Attempt to get twin (implicitly enabling twin first) + assert session._mqtt_client._twin_responses_enabled is False + with pytest.raises(MQTTConnectionDroppedError) as e_info: + await session.get_twin() + assert e_info.value.rc in [paho.MQTT_ERR_CONN_LOST, paho.MQTT_ERR_KEEPALIVE] + del e_info + assert session._mqtt_client._twin_responses_enabled is False + + assert session.connected is False leak_tracker.check_for_leaks() - @pytest.mark.it("Connects the transport if necessary") - @pytest.mark.quicktest_suite - async def test_patch_reported_connect_if_necessary( - self, client, random_reported_props, service_helper, 
leak_tracker + @pytest.mark.skip("get_twin doesn't time out if no response") + @pytest.mark.keep_alive(5) + @pytest.mark.it("Raises Error on get_twin if network error causes request or response to fail") + async def test_get_twin_raises_if_network_error_on_request_or_response( + self, dropper, leak_tracker, session, failure_type ): leak_tracker.set_initial_object_list() - await client.disconnect() + async with session: + assert session.connected is True - assert not client.connected - await client.patch_twin_reported_properties(random_reported_props) - assert client.connected + assert session._mqtt_client._twin_responses_enabled is False + await session.get_twin() + assert session._mqtt_client._twin_responses_enabled is True - received_patch = await service_helper.get_next_reported_patch_arrival() - assert ( - received_patch[const.REPORTED][const.TEST_CONTENT] - == random_reported_props[const.TEST_CONTENT] - ) + # Disrupt network + if failure_type == PACKET_DROP: + dropper.drop_outgoing() + elif failure_type == PACKET_REJECT: + dropper.reject_outgoing() - twin = await client.get_twin() - assert twin[const.REPORTED][const.TEST_CONTENT] == random_reported_props[const.TEST_CONTENT] + # TODO: is this the right exception? 
+ with pytest.raises(asyncio.CancelledError): + await session.get_twin() + assert session.connected is False leak_tracker.check_for_leaks() + # TODO "Succeeds if network failure resolves before session can disconnect" -@pytest.mark.dropped_connection -@pytest.mark.describe("Client Reported Properties with dropped connection") -@pytest.mark.keep_alive(5) -class TestReportedPropertiesDroppedConnection(object): - # TODO: split drop tests between first and second patches - - @pytest.mark.it("Updates reported properties if connection drops before sending") - async def test_updates_reported_if_drop_before_sending( - self, client, random_reported_props, dropper, service_helper, leak_tracker +@pytest.mark.describe("Client Reported Properties") +class TestReportedProperties(object): + @pytest.mark.it("Can set a simple reported property") + @pytest.mark.parametrize(*twin_enabled_and_disabled) + @pytest.mark.quicktest_suite + async def test_sends_simple_reported_patch( + self, leak_tracker, service_helper, session, twin_enabled, random_reported_props ): leak_tracker.set_initial_object_list() - assert client.connected - dropper.drop_outgoing() + async with session: + # Enable twin responses if necessary + assert session._mqtt_client._twin_responses_enabled is False + if twin_enabled: + await session.get_twin() + assert session._mqtt_client._twin_responses_enabled is True - send_task = asyncio.ensure_future( - client.patch_twin_reported_properties(random_reported_props) - ) - while client.connected: - await asyncio.sleep(1) + # patch properties + await session.update_reported_properties(random_reported_props) - assert not send_task.done() + assert session._mqtt_client._twin_responses_enabled is True - dropper.restore_all() - while not client.connected: - await asyncio.sleep(1) + # wait for patch to arrive at service and verify + received_patch = await service_helper.get_next_reported_patch_arrival() + assert ( + received_patch[const.REPORTED][const.TEST_CONTENT] + == 
random_reported_props[const.TEST_CONTENT] + ) - await send_task + # get twin from the service and verify content + twin = await session.get_twin() + assert ( + twin[const.REPORTED][const.TEST_CONTENT] + == random_reported_props[const.TEST_CONTENT] + ) - received_patch = await service_helper.get_next_reported_patch_arrival() - assert ( - received_patch[const.REPORTED][const.TEST_CONTENT] - == random_reported_props[const.TEST_CONTENT] - ) + leak_tracker.check_for_leaks() + + @pytest.mark.it("Raises correct exception for un-serializable patch") + @pytest.mark.parametrize(*twin_enabled_and_disabled) + async def test_bad_reported_patch_raises(self, leak_tracker, session, twin_enabled): + leak_tracker.set_initial_object_list() + + async with session: + # Enable twin responses if necessary + assert session._mqtt_client._twin_responses_enabled is False + if twin_enabled: + await session.get_twin() + assert session._mqtt_client._twin_responses_enabled is True + + # There's no way to serialize a function. 
+ def thing_that_cant_serialize(): + pass + + with pytest.raises(TypeError): + await session.update_reported_properties(thing_that_cant_serialize) + + assert session.connected is False - # TODO: investigate leak - # leak_tracker.check_for_leaks() + leak_tracker.check_for_leaks() - @pytest.mark.it("Updates reported properties if connection rejects send") - async def test_updates_reported_if_reject_before_sending( - self, client, random_reported_props, dropper, service_helper, leak_tracker + @pytest.mark.it("Can clear a reported property") + @pytest.mark.parametrize(*twin_enabled_and_disabled) + @pytest.mark.quicktest_suite + async def test_clear_property( + self, leak_tracker, service_helper, session, twin_enabled, random_reported_props ): leak_tracker.set_initial_object_list() - assert client.connected - dropper.reject_outgoing() + async with session: + # Enable twin responses if necessary + assert session._mqtt_client._twin_responses_enabled is False + if twin_enabled: + await session.get_twin() + assert session._mqtt_client._twin_responses_enabled is True + + # patch properties and verify that the service received the patch + await session.update_reported_properties(random_reported_props) + received_patch = await service_helper.get_next_reported_patch_arrival() + assert ( + received_patch[const.REPORTED][const.TEST_CONTENT] + == random_reported_props[const.TEST_CONTENT] + ) + + # send a patch clearing properties and verify that the service received that patch + await session.update_reported_properties(reset_reported_props) + received_patch = await service_helper.get_next_reported_patch_arrival() + assert ( + received_patch[const.REPORTED][const.TEST_CONTENT] + == reset_reported_props[const.TEST_CONTENT] + ) + + # get the twin and verify that the properties are no longer part of the twin + twin = await session.get_twin() + assert const.TEST_CONTENT not in twin[const.REPORTED] + + assert session.connected is False - send_task = asyncio.ensure_future( - 
client.patch_twin_reported_properties(random_reported_props) - ) - while client.connected: - await asyncio.sleep(1) + leak_tracker.check_for_leaks() - assert not send_task.done() + @pytest.mark.it("Raises SessionError if there is no connection") + @pytest.mark.parametrize(*twin_enabled_and_disabled) + @pytest.mark.quicktest_suite + async def test_no_connection_raises_error( + self, leak_tracker, session, random_reported_props, twin_enabled + ): + leak_tracker.set_initial_object_list() - dropper.restore_all() - while not client.connected: - await asyncio.sleep(1) + # Enable twin responses if necessary + assert session._mqtt_client._twin_responses_enabled is False + if twin_enabled: + async with session: + await session.get_twin() - await send_task + assert session._mqtt_client._twin_responses_enabled is True + assert session.connected is False - received_patch = await service_helper.get_next_reported_patch_arrival() - assert ( - received_patch[const.REPORTED][const.TEST_CONTENT] - == random_reported_props[const.TEST_CONTENT] - ) + with pytest.raises(SessionError): + await session.update_reported_properties(random_reported_props) + assert session.connected is False - # TODO: investigate leak - # leak_tracker.check_for_leaks() + leak_tracker.check_for_leaks() @pytest.mark.describe("Client Desired Properties") @@ -191,36 +276,147 @@ class TestDesiredProperties(object): @pytest.mark.it("Receives a patch for a simple desired property") @pytest.mark.quicktest_suite async def test_receives_simple_desired_patch( - self, client, event_loop, service_helper, leak_tracker + self, event_loop, leak_tracker, service_helper, session ): + random_dict = get_random_dict() leak_tracker.set_initial_object_list() - received_patch = None - received = asyncio.Event() + # Make a task to pull incoming patches from the generator and pyt + # them into a queue. + # In py310, anext can do the same thing, but we need to support older + # versions. 
+ + queue = asyncio.Queue() + registered = asyncio.Event() + + async def listener(sess): + try: + async with sess.desired_property_updates() as patches: + # signal that we're registered + registered.set() + async for patch in patches: + await queue.put(patch) + except asyncio.CancelledError: + # this happens during shutdown. no need to log this. + raise + except BaseException: + # Without this line, exceptions get silently ignored until + # we await the listener task. + logger.error("Exception", exc_info=True) + raise + + async with session: + listener_task = asyncio.create_task(listener(session)) + await registered.wait() + + await service_helper.set_desired_properties( + {const.TEST_CONTENT: random_dict}, + ) + + received_patch = await queue.get() + assert received_patch[const.TEST_CONTENT] == random_dict + + twin = await session.get_twin() + assert twin[const.DESIRED][const.TEST_CONTENT] == random_dict + + assert session.connected is False + + # make sure our listener ended with an error when we disconnected. + logger.info("Waiting for listener_task to complete") + with pytest.raises(asyncio.CancelledError): + await listener_task + logger.info("Done waiting for listener_task") + + leak_tracker.check_for_leaks() + + @pytest.mark.it("Receives a patch for a simple desired property entering session context twice") + @pytest.mark.quicktest_suite + async def test_receives_simple_desired_patch_enter_session_twice( + self, event_loop, leak_tracker, service_helper, session + ): + random_dict = get_random_dict() + leak_tracker.set_initial_object_list() - async def handle_on_patch_received(patch): - nonlocal received_patch, received - print("received {}".format(patch)) - received_patch = patch - event_loop.call_soon_threadsafe(received.set) + # Make a task to pull incoming patches from the generator and pyt + # them into a queue. + # In py310, anext can do the same thing, but we need to support older + # versions. 
+ + queue = asyncio.Queue() + registered = asyncio.Event() + + async def listener(sess): + try: + # This `async with` is the only difference from the previous test. + async with sess: + async with sess.desired_property_updates() as patches: + # signal that we're registered + registered.set() + async for patch in patches: + await queue.put(patch) + except asyncio.CancelledError: + # this happens during shutdown. no need to log this. + raise + except Exception: + # Without this line, exceptions get silently ignored until + # we await the listener task. + logger.error("Exception", exc_info=True) + raise + + async with session: + listener_task = asyncio.create_task(listener(session)) + await registered.wait() + + await service_helper.set_desired_properties( + {const.TEST_CONTENT: random_dict}, + ) + + received_patch = await queue.get() + assert received_patch[const.TEST_CONTENT] == random_dict + + twin = await session.get_twin() + assert twin[const.DESIRED][const.TEST_CONTENT] == random_dict + + assert session.connected is False + + # make sure our listener ended with an error when we disconnected. 
+ logger.info("Waiting for listener_task to complete") + with pytest.raises(asyncio.CancelledError): + await listener_task + logger.info("Done waiting for listener_task") - client.on_twin_desired_properties_patch_received = handle_on_patch_received + leak_tracker.check_for_leaks() + @pytest.mark.skip("leaks") + @pytest.mark.it("Receives a patch for a simple desired property using anext") + @pytest.mark.quicktest_suite + @pytest.mark.skipif( + sys.version_info.major == 3 and sys.version_info.minor < 10, + reason="anext was not introduced until 3.10", + ) + async def test_receives_simple_desired_patch_using_anext( + self, event_loop, leak_tracker, service_helper, session + ): + leak_tracker.set_initial_object_list() random_dict = get_random_dict() - await service_helper.set_desired_properties( - {const.TEST_CONTENT: random_dict}, - ) - await asyncio.wait_for(received.wait(), 60) - assert received.is_set() + # Python 3.10 makes our lives easier because we can use anext() and treat the generator like a queue + + async with session: + async with session.desired_property_updates() as patches: + await service_helper.set_desired_properties( + {const.TEST_CONTENT: random_dict}, + ) - assert received_patch[const.TEST_CONTENT] == random_dict + received_patch = await anext(patches) # noqa: F821 + assert received_patch[const.TEST_CONTENT] == random_dict - twin = await client.get_twin() - assert twin[const.DESIRED][const.TEST_CONTENT] == random_dict + twin = await session.get_twin() + assert twin[const.DESIRED][const.TEST_CONTENT] == random_dict - # TODO: investigate leak - # leak_tracker.check_for_leaks() + assert session.connected is False + + leak_tracker.check_for_leaks() # TODO: etag tests, version tests diff --git a/tests/e2e/iothub_e2e/aio/test_twin_stress.py b/tests/e2e/iothub_e2e/aio/test_twin_stress.py deleted file mode 100644 index b873becf2..000000000 --- a/tests/e2e/iothub_e2e/aio/test_twin_stress.py +++ /dev/null @@ -1,388 +0,0 @@ -# Copyright (c) Microsoft 
Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -import asyncio -import pytest -import logging -import parametrize -import const -import dev_utils -from retry_async import retry_exponential_backoff_with_jitter - -logger = logging.getLogger(__name__) -logger.setLevel(level=logging.INFO) - -pytestmark = pytest.mark.asyncio - - -@pytest.fixture -def toxic(): - pass - - -reset_reported_props = {const.TEST_CONTENT: None} - -call_with_retry = retry_exponential_backoff_with_jitter - - -def get_random_property_value(): - return dev_utils.get_random_string(100, True) - - -def wrap_as_reported_property(value, key=None): - if key: - return {const.TEST_CONTENT: {key: value}} - else: - return {const.TEST_CONTENT: value} - - -@pytest.mark.timeout(600) -@pytest.mark.stress -@pytest.mark.describe("Client Stress") -@pytest.mark.parametrize(*parametrize.auto_connect_disabled) -@pytest.mark.parametrize(*parametrize.connection_retry_disabled) -class TestTwinStress(object): - @pytest.mark.parametrize( - "iteration_count", [pytest.param(10, id="10 updates"), pytest.param(50, id="50 updates")] - ) - @pytest.mark.it("Can send continuous reported property updates, one-at-a-time") - async def test_stress_serial_reported_property_updates( - self, client, service_helper, toxic, iteration_count, leak_tracker - ): - """ - Send reported property updates, one at a time, and verify that each one - has been received at the service. Do not overlap these calls. - """ - leak_tracker.set_initial_object_list() - - leak_tracker.set_initial_object_list() - - await call_with_retry(client, client.patch_twin_reported_properties, reset_reported_props) - - for i in range(iteration_count): - logger.info("Iteration {} of {}".format(i, iteration_count)) - - # Update the reported property. 
- patch = wrap_as_reported_property(get_random_property_value()) - await call_with_retry(client, client.patch_twin_reported_properties, patch) - - # Wait for that reported property to arrive at the service. - received = False - while not received: - received_patch = await service_helper.get_next_reported_patch_arrival() - if ( - const.REPORTED in received_patch - and received_patch[const.REPORTED][const.TEST_CONTENT] - == patch[const.TEST_CONTENT] - ): - received = True - else: - logger.info( - "Wrong patch received. Expecting {}, got {}".format(patch, received_patch) - ) - - leak_tracker.check_for_leaks() - - @pytest.mark.parametrize( - "iteration_count, batch_size", - [ - pytest.param(20, 10, id="20 updates, 10 at a time"), - pytest.param(250, 25, id="250 updates, 25 at a time"), - ], - ) - @pytest.mark.it("Can send continuous overlapped reported property updates") - async def test_stress_parallel_reported_property_updates( - self, client, service_helper, toxic, iteration_count, batch_size, leak_tracker - ): - """ - Update reported properties with many overlapped calls. Work in batches - with `batch_size` overlapped calls in a batch. Verify that the updates arrive - at the service. - """ - leak_tracker.set_initial_object_list() - - leak_tracker.set_initial_object_list() - - await call_with_retry(client, client.patch_twin_reported_properties, reset_reported_props) - - for _ in range(0, iteration_count, batch_size): - props = { - "key_{}".format(k): get_random_property_value() for k in range(0, iteration_count) - } - - # Do overlapped calls to update `batch_size` properties. 
- tasks = [ - call_with_retry( - client, - client.patch_twin_reported_properties, - wrap_as_reported_property(props[key], key), - ) - for key in props.keys() - ] - results = await asyncio.gather(*tasks, return_exceptions=True) - for result in results: - if isinstance(result, Exception): - raise result - - # wait for these properties to arrive at the service - count_received = 0 - while count_received < batch_size: - received_patch = await service_helper.get_next_reported_patch_arrival(timeout=60) - received_test_content = received_patch[const.REPORTED][const.TEST_CONTENT] or {} - logger.info("received {}".format(received_test_content)) - - if isinstance(received_test_content, dict): - # We check to make sure received_test_content is a dict because it may be - # a string left over from a previous test case. - # This can happen if if the tests are running fast and the reported - # property updates are being processed slowly. - for key in received_test_content.keys(): - logger.info("Received {} = {}".format(key, received_test_content[key])) - if key in props: - if received_test_content[key] == props[key]: - logger.info("Key {} received as expected.".format(key)) - # Set the value to None so we know that it's been received - props[key] = None - count_received += 1 - else: - logger.info( - "Ignoring unexpected value for key {}. Received = {}, expected = {}".format( - key, received_test_content[key], props[key] - ) - ) - - leak_tracker.check_for_leaks() - - @pytest.mark.parametrize( - "iteration_count", [pytest.param(10, id="10 updates"), pytest.param(50, id="50 updates")] - ) - @pytest.mark.it("Can receive continuous desired property updates that were sent one-at-a-time") - async def test_stress_serial_desired_property_updates( - self, client, service_helper, toxic, iteration_count, event_loop, leak_tracker - ): - """ - Update desired properties, one at a time, and verify that the desired property arrives - at the client before the next update. 
- """ - leak_tracker.set_initial_object_list() - - patches = asyncio.Queue() - - async def handle_on_patch_received(patch): - logger.info("received {}".format(patch)) - # marshal this back into our event loop so we can safely use the asyncio.queue - asyncio.run_coroutine_threadsafe(patches.put(patch), event_loop) - - client.on_twin_desired_properties_patch_received = handle_on_patch_received - - for i in range(iteration_count): - logger.info("Iteration {} of {}".format(i, iteration_count)) - - # update a single desired property - property_value = get_random_property_value() - await service_helper.set_desired_properties( - {const.TEST_CONTENT: property_value}, - ) - - # wait for the property update to arrive at the client - received_patch = await asyncio.wait_for(patches.get(), 60) - assert received_patch[const.TEST_CONTENT] == property_value - - leak_tracker.check_for_leaks() - - @pytest.mark.parametrize( - "iteration_count, batch_size", - [ - pytest.param(20, 10, id="20 updates, 10 at a time"), - pytest.param(250, 25, id="250 updates, 25 at a time"), - ], - ) - @pytest.mark.it( - "Can receive continuous desired property updates that may have been sent in parallel" - ) - async def test_stress_parallel_desired_property_updates( - self, client, service_helper, toxic, iteration_count, batch_size, event_loop, leak_tracker - ): - """ - Update desired properties in batches. Each batch updates `batch_size` properties, - with each property being updated in it's own `PATCH`. 
- """ - leak_tracker.set_initial_object_list() - - patches = asyncio.Queue() - - async def handle_on_patch_received(patch): - logger.info("received {}".format(patch)) - # use run_coroutine_threadsafe to marshal this back into our event - # loop so we can safely use the asyncio.queue - asyncio.run_coroutine_threadsafe(patches.put(patch), event_loop) - - client.on_twin_desired_properties_patch_received = handle_on_patch_received - - props = {"key_{}".format(k): None for k in range(0, batch_size)} - - await service_helper.set_desired_properties({const.TEST_CONTENT: None}) - - for _ in range(0, iteration_count, batch_size): - - # update `batch_size` properties, each with a call to `set_desired_properties` - props = {"key_{}".format(k): get_random_property_value() for k in range(0, batch_size)} - tasks = [ - service_helper.set_desired_properties({const.TEST_CONTENT: {key: props[key]}}) - for key in props.keys() - ] - results = await asyncio.gather(*tasks, return_exceptions=True) - for result in results: - if isinstance(result, Exception): - raise result - - # Wait for those properties to arrive at the client - count_received = 0 - while count_received < batch_size: - received_patch = await asyncio.wait_for(patches.get(), 60) - received_test_content = received_patch[const.TEST_CONTENT] or {} - - for key in received_test_content: - logger.info("Received {} = {}".format(key, received_test_content[key])) - if key in props: - if received_test_content[key] == props[key]: - logger.info("Key {} received as expected.".format(key)) - # Set the value to None so we know that it's been received - props[key] = None - count_received += 1 - else: - logger.info( - "Ignoring unexpected value for key {}. 
Received = {}, expected = {}".format( - key, received_test_content[key], props[key] - ) - ) - - leak_tracker.check_for_leaks() - - @pytest.mark.parametrize( - "iteration_count", [pytest.param(10, id="10 updates"), pytest.param(50, id="50 updates")] - ) - @pytest.mark.it("Can continuously call get_twin and get valid property values") - async def test_stress_serial_get_twin_calls( - self, client, service_helper, toxic, iteration_count, leak_tracker - ): - """ - Call `get_twin` once-at-a-time to verify that updated properties show up. This test - calls `get_twin()` `iteration_count` times. Once a reported property shows up in the - twin, that property is updated to be verified in future `get_twin` calls. - """ - leak_tracker.set_initial_object_list() - - last_property_value = None - current_property_value = None - - for i in range(iteration_count): - logger.info("Iteration {} of {}".format(i, iteration_count)) - - # Set a reported property - if not current_property_value: - current_property_value = get_random_property_value() - logger.info("patching to {}".format(current_property_value)) - await call_with_retry( - client, - client.patch_twin_reported_properties, - wrap_as_reported_property(current_property_value), - ) - - # Call get_twin to verify that this property arrived. - # reported properties aren't immediately reflected in `get_twin` calls, - # so we have to account for retrieving old property values. - twin = await call_with_retry(client, client.get_twin) - logger.info("Got {}".format(twin[const.REPORTED][const.TEST_CONTENT])) - if twin[const.REPORTED][const.TEST_CONTENT] == current_property_value: - logger.info("it's a match.") - last_property_value = current_property_value - current_property_value = None - elif last_property_value: - # If it's not the current value, then it _must_ be the last value - # We can only verify this if we know what the old value was. 
- assert twin[const.REPORTED][const.TEST_CONTENT] == last_property_value - - assert last_property_value, "No patches with updated properties were received" - - leak_tracker.check_for_leaks() - - leak_tracker.check_for_leaks() - - @pytest.mark.parametrize( - "iteration_count, batch_size", - [ - pytest.param(20, 10, id="20 updates, 10 at a time"), - pytest.param(250, 25, id="250 updates, 25 at a time"), - pytest.param(1000, 50, id="1000 updates, 50 at a time"), - ], - ) - @pytest.mark.it("Can continuously make overlapped get_twin calls and get valid property values") - async def test_stress_parallel_get_twin_calls( - self, client, service_helper, toxic, iteration_count, batch_size, leak_tracker - ): - """ - Call `get_twin` many times, overlapped, to verify that updated properties show up. This test - calls `get_twin()` `iteration_count` times. Once a reported property shows up in the - twin, that property is updated to be verified in future `get_twin` calls. - """ - leak_tracker.set_initial_object_list() - - last_property_value = None - current_property_value = get_random_property_value() - - await call_with_retry( - client, - client.patch_twin_reported_properties, - wrap_as_reported_property(current_property_value), - ) - ready_to_test = False - - while not ready_to_test: - twin = await call_with_retry(client, client.get_twin) - if twin[const.REPORTED].get(const.TEST_CONTENT, "") == current_property_value: - logger.info("Initial value set") - ready_to_test = True - else: - logger.info("Waiting for initial value. 
Sleeping for 5") - await asyncio.sleep(5) - - for i in range(0, iteration_count, batch_size): - logger.info("Iteration {} of {}".format(i, iteration_count)) - - # Update the property if it's time to update - if not current_property_value: - current_property_value = get_random_property_value() - logger.info("patching to {}".format(current_property_value)) - await call_with_retry( - client, - client.patch_twin_reported_properties, - wrap_as_reported_property(current_property_value), - ) - - # Call `get_twin` many times overlapped and verify that we get either - # the old property value (if we know it), or the new property value. - tasks = [call_with_retry(client, client.get_twin) for _ in range(batch_size)] - results = await asyncio.gather(*tasks, return_exceptions=True) - got_a_match = False - - for result in results: - if isinstance(result, Exception): - raise result - - twin = result - logger.info("Got {}".format(twin[const.REPORTED][const.TEST_CONTENT])) - if twin[const.REPORTED][const.TEST_CONTENT] == current_property_value: - logger.info("it's a match.") - got_a_match = True - elif last_property_value: - # if it's not the current value, then it _must_ be the last value - assert twin[const.REPORTED][const.TEST_CONTENT] == last_property_value - - # Once we verify that `get_twin` returned the new property value, we set - # it to `None` so the next iteration of the loop can update this value. 
- if got_a_match: - last_property_value = current_property_value - current_property_value = None - - leak_tracker.check_for_leaks() diff --git a/tests/e2e/iothub_e2e/client_fixtures.py b/tests/e2e/iothub_e2e/client_fixtures.py index 752aeb1f7..5c83bf189 100644 --- a/tests/e2e/iothub_e2e/client_fixtures.py +++ b/tests/e2e/iothub_e2e/client_fixtures.py @@ -15,26 +15,6 @@ def module_id(device_identity): return None -@pytest.fixture(scope="function") -def connection_retry(request): - # let tests use @pytest.mark.connection_retry(x) to set connection_retry - marker = request.node.get_closest_marker("connection_retry") - if marker: - return marker.args[0] - else: - return True - - -@pytest.fixture(scope="function") -def auto_connect(request): - # let tests use @pytest.mark.auto_connect(x) to set auto_connect - marker = request.node.get_closest_marker("auto_connect") - if marker: - return marker.args[0] - else: - return True - - @pytest.fixture(scope="function") def websockets(): return test_config.config.transport == test_config.TRANSPORT_MQTT_WS @@ -62,10 +42,8 @@ def sastoken_ttl(request): @pytest.fixture(scope="function") -def client_kwargs(auto_connect, connection_retry, websockets, keep_alive, sastoken_ttl): +def client_kwargs(websockets, keep_alive, sastoken_ttl): kwargs = {} - kwargs["auto_connect"] = auto_connect - kwargs["connection_retry"] = connection_retry kwargs["websockets"] = websockets if keep_alive is not None: kwargs["keep_alive"] = keep_alive diff --git a/tests/e2e/iothub_e2e/conftest.py b/tests/e2e/iothub_e2e/conftest.py index 0d82d1523..f059a19b9 100644 --- a/tests/e2e/iothub_e2e/conftest.py +++ b/tests/e2e/iothub_e2e/conftest.py @@ -14,8 +14,6 @@ from drop_fixtures import dropper # noqa: F401 from client_fixtures import ( # noqa: F401 client_kwargs, - auto_connect, - connection_retry, websockets, device_id, module_id, @@ -32,6 +30,7 @@ logging.getLogger("paho").setLevel(level=logging.DEBUG) 
logging.getLogger("azure.iot").setLevel(level=logging.DEBUG) + logger = logging.getLogger(__name__) logger.setLevel(level=logging.INFO) @@ -189,10 +188,7 @@ def pytest_runtest_setup(item): # tests that use iptables need to be skipped on Windows if is_windows(): - for x in item.iter_markers("uses_iptables"): - pytest.skip("test uses iptables") - return - for x in item.iter_markers("dropped_connection"): + if "dropper" in item.fixturenames: pytest.skip("test uses iptables") return diff --git a/tests/e2e/iothub_e2e/parametrize.py b/tests/e2e/iothub_e2e/parametrize.py index e830b96c7..41f937ac1 100644 --- a/tests/e2e/iothub_e2e/parametrize.py +++ b/tests/e2e/iothub_e2e/parametrize.py @@ -17,47 +17,3 @@ pytest.param(False, False, id="with no request payload and no response payload"), ], ] - -connection_retry_disabled_and_enabled = [ - "connection_retry", - [ - pytest.param(True, id="connection_retry enabled"), - pytest.param(False, id="connection_retry disabled"), - ], -] - -connection_retry_enabled = [ - "connection_retry", - [ - pytest.param(True, id="connection_retry enabled"), - ], -] - -connection_retry_disabled = [ - "connection_retry", - [ - pytest.param(False, id="connection_retry disabled"), - ], -] - -auto_connect_disabled_and_enabled = [ - "auto_connect", - [ - pytest.param(True, id="auto_connect enabled"), - pytest.param(False, id="auto_connect disabled"), - ], -] - -auto_connect_enabled = [ - "auto_connect", - [ - pytest.param(True, id="auto_connect enabled"), - ], -] - -auto_connect_disabled = [ - "auto_connect", - [ - pytest.param(False, id="auto_connect disabled"), - ], -] diff --git a/tests/e2e/iothub_e2e/pytest.ini b/tests/e2e/iothub_e2e/pytest.ini index 0094ca0ed..0db2a1b8d 100644 --- a/tests/e2e/iothub_e2e/pytest.ini +++ b/tests/e2e/iothub_e2e/pytest.ini @@ -1,23 +1,16 @@ [pytest] -timeout=120 +timeout=300 testdox_format=plaintext junit_logging=all junit_family=xunit2 junit_log_passing_tests=True asyncio_mode=auto -# --force-testdox to always use 
testdox format, even when redirecting to file addopts= --testdox --force-testdox --strict-markers - -m "not stress" norecursedirs=__pycache__, *.egg-info markers= - dropped_connection: includes tests that simplate dropped network connections. - uses_iptables: tests that use iptables. skipped on Windows. quicktest_suite: tests which are part of the quick-test suite. - stress: run stress tests keep_alive: use to pass custom keep_alive from tests into fixtures sastoken_ttl: use to pass custom sastoken_ttl from tests into fixtures - connection_retry: use to pass custom connection_retry from tests into fixtures - auto_connect: use to pass custom auto_connect from tests into fixtures diff --git a/tests/e2e/iothub_e2e/sync/conftest.py b/tests/e2e/iothub_e2e/sync/conftest.py deleted file mode 100644 index d04d9c5cd..000000000 --- a/tests/e2e/iothub_e2e/sync/conftest.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -import pytest -import time -from dev_utils import test_env, ServiceHelperSync -import logging -import datetime -from utils import create_client_object -from azure.iot.device.iothub import IoTHubDeviceClient, IoTHubModuleClient - -logger = logging.getLogger(__name__) -logger.setLevel(level=logging.INFO) - - -@pytest.fixture(scope="function") -def brand_new_client(device_identity, client_kwargs, service_helper, device_id, module_id): - service_helper.set_identity(device_id, module_id) - - # Keep this here. It is useful to see this info inside the inside devops pipeline test failures. 
- logger.info( - "Connecting device_id={}, module_id={}, to hub={} at {} (UTC)".format( - device_id, module_id, test_env.IOTHUB_HOSTNAME, datetime.datetime.utcnow() - ) - ) - - client = create_client_object( - device_identity, client_kwargs, IoTHubDeviceClient, IoTHubModuleClient - ) - - yield client - - logger.info("---------------------------------------") - logger.info("test is complete. Shutting down client") - logger.info("---------------------------------------") - - client.shutdown() - - logger.info("-------------------------------------------") - logger.info("test is complete. client shutdown complete") - logger.info("-------------------------------------------") - - -@pytest.fixture(scope="function") -def client(brand_new_client): - client = brand_new_client - - client.connect() - - yield client - - -@pytest.fixture(scope="session") -def service_helper(): - service_helper = ServiceHelperSync( - iothub_connection_string=test_env.IOTHUB_CONNECTION_STRING, - eventhub_connection_string=test_env.EVENTHUB_CONNECTION_STRING, - eventhub_consumer_group=test_env.EVENTHUB_CONSUMER_GROUP, - ) - time.sleep(3) - yield service_helper - - logger.info("----------------------------") - logger.info("shutting down service_helper") - logger.info("----------------------------") - - service_helper.shutdown() - - logger.info("---------------------------------") - logger.info("service helper shut down complete") - logger.info("---------------------------------") diff --git a/tests/e2e/iothub_e2e/sync/test_sync_c2d.py b/tests/e2e/iothub_e2e/sync/test_sync_c2d.py deleted file mode 100644 index 94cfedac3..000000000 --- a/tests/e2e/iothub_e2e/sync/test_sync_c2d.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-import pytest -import logging -import json -import threading -from dev_utils import get_random_dict - -logger = logging.getLogger(__name__) -logger.setLevel(level=logging.INFO) - -# TODO: add tests for various application properties -# TODO: is there a way to call send_c2d so it arrives as an object rather than a JSON string? - - -@pytest.mark.describe("Client C2d") -class TestReceiveC2d(object): - @pytest.mark.it("Can receive C2D") - @pytest.mark.quicktest_suite - def test_sync_receive_c2d(self, client, service_helper, leak_tracker): - leak_tracker.set_initial_object_list() - - message = json.dumps(get_random_dict()) - - received_message = None - received = threading.Event() - - def handle_on_message_received(message): - nonlocal received_message, received - logger.info("received {}".format(message)) - received_message = message - received.set() - - client.on_message_received = handle_on_message_received - - service_helper.send_c2d(message, {}) - - received.wait(timeout=60) - assert received.is_set() - - assert received_message.data.decode("utf-8") == message - - received_message = None # so this isn't tagged as a leak - leak_tracker.check_for_leaks() diff --git a/tests/e2e/iothub_e2e/sync/test_sync_connect_disconnect.py b/tests/e2e/iothub_e2e/sync/test_sync_connect_disconnect.py deleted file mode 100644 index 927681170..000000000 --- a/tests/e2e/iothub_e2e/sync/test_sync_connect_disconnect.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-import pytest -import logging -import time -import threading -import parametrize - -logger = logging.getLogger(__name__) -logger.setLevel(level=logging.INFO) - - -@pytest.mark.describe("Client object") -class TestConnectDisconnect(object): - @pytest.mark.it("Can disconnect and reconnect") - @pytest.mark.parametrize(*parametrize.connection_retry_disabled_and_enabled) - @pytest.mark.parametrize(*parametrize.auto_connect_disabled_and_enabled) - @pytest.mark.quicktest_suite - def test_sync_connect_disconnect(self, brand_new_client, leak_tracker): - leak_tracker.set_initial_object_list() - - client = brand_new_client - - client.connect() - assert client.connected - - client.disconnect() - assert not client.connected - - client.connect() - assert client.connected - - leak_tracker.check_for_leaks() - - @pytest.mark.it( - "Can do a manual connect in the `on_connection_state_change` call that is notifying the user about a disconnect." - ) - @pytest.mark.parametrize(*parametrize.connection_retry_disabled_and_enabled) - @pytest.mark.parametrize(*parametrize.auto_connect_disabled_and_enabled) - # see "This assert fails because of initial and secondary disconnects" below - @pytest.mark.skip(reason="two stage disconnect causes assertion in test code") - def test_sync_connect_in_the_middle_of_disconnect( - self, brand_new_client, service_helper, random_message, leak_tracker - ): - """ - Explanation: People will call `connect` inside `on_connection_state_change` handlers. - We have to make sure that we can handle this without getting stuck in a bad state. - """ - leak_tracker.set_initial_object_list() - - client = brand_new_client - assert client - - reconnected_event = threading.Event() - - def handle_on_connection_state_change(): - nonlocal reconnected_event - - if client.connected: - logger.info("handle_on_connection_state_change connected. nothing to do") - else: - logger.info("handle_on_connection_state_change disconnected. 
reconnecting.") - client.connect() - assert client.connected - reconnected_event.set() - logger.info("reconnect event set") - - client.on_connection_state_change = handle_on_connection_state_change - - # connect - client.connect() - assert client.connected - - # disconnect. - reconnected_event.clear() - logger.info("Calling client.disconnect.") - client.disconnect() - - # wait for handle_on_connection_state_change to reconnect - logger.info("waiting for reconnect_event to be set.") - reconnected_event.wait() - - logger.info( - "reconnect_event.wait() returned. client.connected={}".format(client.connected) - ) - # This assert fails because of initial and secondary disconnects - assert client.connected - - # sleep a while and make sure that we're still connected. - time.sleep(3) - assert client.connected - - # finally, send a message to makes reu we're _really_ connected - client.send_message(random_message) - event = service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert event - - leak_tracker.check_for_leaks() - - @pytest.mark.it( - "Can do a manual disconnect in the `on_connection_state_change` call that is notifying the user about a connect." - ) - @pytest.mark.parametrize(*parametrize.connection_retry_disabled_and_enabled) - @pytest.mark.parametrize(*parametrize.auto_connect_disabled_and_enabled) - @pytest.mark.parametrize( - "first_connect", - [pytest.param(True, id="First connection"), pytest.param(False, id="Second connection")], - ) - def test_sync_disconnect_in_the_middle_of_connect( - self, brand_new_client, service_helper, random_message, first_connect, leak_tracker - ): - """ - Explanation: This is the inverse of `test_connect_in_the_middle_of_disconnect`. This is - less likely to be a user scenario, but it lets us test with unusual-but-specific timing - on the call to `disconnect`. 
- """ - leak_tracker.set_initial_object_list() - - client = brand_new_client - assert client - disconnect_on_next_connect_event = False - - disconnected_event = threading.Event() - - def handle_on_connection_state_change(): - nonlocal disconnected_event - - if client.connected: - if disconnect_on_next_connect_event: - logger.info("connected. disconnecting now") - client.disconnect() - disconnected_event.set() - else: - logger.info("connected, but nothing to do") - else: - logger.info("disconnected. nothing to do") - - client.on_connection_state_change = handle_on_connection_state_change - - if not first_connect: - # connect - client.connect() - assert client.connected - - # disconnect. - client.disconnect() - - assert not client.connected - - # now, connect (maybe for the second time), and disconnect inside the on_connected handler - disconnect_on_next_connect_event = True - disconnected_event.clear() - client.connect() - - # and wait for us to disconnect - disconnected_event.wait() - assert not client.connected - - # sleep a while and make sure that we're still disconnected. 
- time.sleep(3) - assert not client.connected - - # finally, connect and make sure we can send a message - disconnect_on_next_connect_event = False - client.connect() - assert client.connected - - client.send_message(random_message) - event = service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert event - - leak_tracker.check_for_leaks() - - -@pytest.mark.dropped_connection -@pytest.mark.describe("Client object with dropped connection") -@pytest.mark.keep_alive(5) -class TestConnectDisconnectDroppedConnection(object): - @pytest.mark.it("disconnects when network drops all outgoing packets") - def test_sync_disconnect_on_drop_outgoing(self, client, dropper, leak_tracker): - """ - This test verifies that the client will disconnect (eventually) if the network starts - dropping packets - """ - leak_tracker.set_initial_object_list() - - client.connect() - assert client.connected - dropper.drop_outgoing() - - while client.connected: - time.sleep(1) - - # we've passed the test. Now wait to reconnect before we check for leaks. Otherwise we - # have a pending ConnectOperation floating around and this would get tagged as a leak. - dropper.restore_all() - while not client.connected: - time.sleep(1) - - leak_tracker.check_for_leaks() - - @pytest.mark.it("disconnects when network rejects all outgoing packets") - def test_sync_disconnect_on_reject_outgoing(self, client, dropper, leak_tracker): - """ - This test verifies that the client will disconnect (eventually) if the network starts - rejecting packets - """ - leak_tracker.set_initial_object_list() - - client.connect() - assert client.connected - dropper.reject_outgoing() - - while client.connected: - time.sleep(1) - - # we've passed the test. Now wait to reconnect before we check for leaks. Otherwise we - # have a pending ConnectOperation floating around and this would get tagged as a leak. 
- dropper.restore_all() - while not client.connected: - time.sleep(1) - - leak_tracker.check_for_leaks() diff --git a/tests/e2e/iothub_e2e/sync/test_sync_infrastructure.py b/tests/e2e/iothub_e2e/sync/test_sync_infrastructure.py deleted file mode 100644 index 0919ad392..000000000 --- a/tests/e2e/iothub_e2e/sync/test_sync_infrastructure.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -import pytest -import uuid - - -@pytest.mark.describe("ServiceHelper object") -class TestServiceHelper(object): - @pytest.mark.it("returns None when wait_for_event_arrival times out") - def test_sync_wait_for_event_arrival(self, client, random_message, service_helper): - - event = service_helper.wait_for_eventhub_arrival(uuid.uuid4(), timeout=2) - assert event is None diff --git a/tests/e2e/iothub_e2e/sync/test_sync_methods.py b/tests/e2e/iothub_e2e/sync/test_sync_methods.py deleted file mode 100644 index 97ed652a2..000000000 --- a/tests/e2e/iothub_e2e/sync/test_sync_methods.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-import pytest -import logging -import time -from dev_utils import get_random_dict -import parametrize -from azure.iot.device.iothub import MethodResponse - -logger = logging.getLogger(__name__) -logger.setLevel(level=logging.INFO) - - -@pytest.fixture -def method_name(): - return "this_is_my_method_name" - - -@pytest.fixture -def method_response_status(): - return 299 - - -@pytest.mark.describe("Client methods") -class TestMethods(object): - @pytest.mark.it("Can handle a simple direct method call") - @pytest.mark.parametrize(*parametrize.all_method_payload_options) - def test_sync_handle_method_call( - self, - client, - method_name, - method_response_status, - include_request_payload, - include_response_payload, - service_helper, - leak_tracker, - ): - leak_tracker.set_initial_object_list() - - if include_request_payload: - request_payload = get_random_dict() - else: - request_payload = None - - if include_response_payload: - response_payload = get_random_dict() - else: - response_payload = None - - def handle_on_method_request_received(request): - nonlocal actual_request - logger.info("Method request for {} received".format(request.name)) - actual_request = request - logger.info("Sending response") - client.send_method_response( - MethodResponse.create_from_method_request( - request, method_response_status, response_payload - ) - ) - - client.on_method_request_received = handle_on_method_request_received - time.sleep(1) # wait for subscribe, etc, to complete - - # invoke the method call - method_response = service_helper.invoke_method(method_name, request_payload) - - # verify that the method request arrived correctly - assert actual_request.name == method_name - if request_payload: - assert actual_request.payload == request_payload - else: - assert not actual_request.payload - - # and make sure the response came back successfully - assert method_response.status == method_response_status - assert method_response.payload == response_payload - - actual_request = 
None # so this isn't tagged as a leak - leak_tracker.check_for_leaks() diff --git a/tests/e2e/iothub_e2e/sync/test_sync_sas_renewal.py b/tests/e2e/iothub_e2e/sync/test_sync_sas_renewal.py deleted file mode 100644 index 2a24167ec..000000000 --- a/tests/e2e/iothub_e2e/sync/test_sync_sas_renewal.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -import pytest -import json -import logging -import threading -import test_config -import parametrize - -logger = logging.getLogger(__name__) -logger.setLevel(level=logging.INFO) - - -@pytest.mark.skipif( - test_config.config.auth not in test_config.AUTH_WITH_RENEWING_TOKEN, - reason="{} auth does not support token renewal".format(test_config.config.auth), -) -@pytest.mark.describe("Client sas renewal code") -@pytest.mark.sastoken_ttl(130) # should renew after 10 seconds -class TestSasRenewal(object): - @pytest.mark.it("Renews and reconnects before expiry") - @pytest.mark.parametrize(*parametrize.connection_retry_disabled_and_enabled) - @pytest.mark.parametrize(*parametrize.auto_connect_disabled_and_enabled) - def test_sync_sas_renews(self, client, service_helper, random_message, leak_tracker): - leak_tracker.set_initial_object_list() - - connected_event = threading.Event() - disconnected_event = threading.Event() - - token_object = client._mqtt_pipeline.pipeline_configuration.sastoken - token_at_connect_time = None - - def handle_on_connection_state_change(): - nonlocal token_at_connect_time - logger.info("handle_on_connection_state_change: {}".format(client.connected)) - if client.connected: - token_at_connect_time = str(token_object) - logger.info("saving token: {}".format(token_at_connect_time)) - - connected_event.set() - else: - disconnected_event.set() - - client.on_connection_state_change = handle_on_connection_state_change - - # setting on_connection_state_change seems to have 
the side effect of - # calling handle_on_connection_state_change once with the initial value. - # Wait for one disconnect/reconnect cycle so we can get past it. - connected_event.wait() - - # OK, we're ready to test. wait for the renewal - token_before_connect = str(token_object) - - disconnected_event.clear() - connected_event.clear() - - logger.info("Waiting for client to disconnect") - disconnected_event.wait() - logger.info("Waiting for client to reconnect") - connected_event.wait() - logger.info("Client reconnected") - - # Finally verify that our token changed. - logger.info("token now = {}".format(str(token_object))) - logger.info("token at_connect = {}".format(str(token_at_connect_time))) - logger.info("token before_connect = {}".format(str(token_before_connect))) - assert str(token_object) == token_at_connect_time - assert not token_before_connect == token_at_connect_time - - # and verify that we can send - client.send_message(random_message) - - # and verify that the message arrived at the service - event = service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert json.dumps(event.message_body) == random_message.data - - leak_tracker.check_for_leaks() diff --git a/tests/e2e/iothub_e2e/sync/test_sync_send_message.py b/tests/e2e/iothub_e2e/sync/test_sync_send_message.py deleted file mode 100644 index 78136019b..000000000 --- a/tests/e2e/iothub_e2e/sync/test_sync_send_message.py +++ /dev/null @@ -1,253 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-import pytest -import logging -import json -import time -import dev_utils -from azure.iot.device.exceptions import OperationCancelled, ClientError - -logger = logging.getLogger(__name__) -logger.setLevel(level=logging.INFO) - - -@pytest.mark.describe("Client send_message method") -class TestSendMessage(object): - @pytest.mark.it("Can send a simple message") - @pytest.mark.quicktest_suite - def test_sync_send_message_simple(self, client, random_message, service_helper, leak_tracker): - leak_tracker.set_initial_object_list() - - client.send_message(random_message) - - event = service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert json.dumps(event.message_body) == random_message.data - - leak_tracker.check_for_leaks() - - @pytest.mark.it("Connects the transport if necessary") - @pytest.mark.quicktest_suite - def test_sync_connect_if_necessary(self, client, random_message, service_helper, leak_tracker): - leak_tracker.set_initial_object_list() - - client.disconnect() - assert not client.connected - - client.send_message(random_message) - assert client.connected - - event = service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert json.dumps(event.message_body) == random_message.data - - leak_tracker.check_for_leaks() - - @pytest.mark.it("Raises correct exception for un-serializable payload") - def test_sync_bad_payload_raises(self, client, leak_tracker): - leak_tracker.set_initial_object_list() - - # There's no way to serialize a function. 
- def thing_that_cant_serialize(): - pass - - with pytest.raises(ClientError) as e_info: - client.send_message(thing_that_cant_serialize) - assert isinstance(e_info.value.__cause__, TypeError) - - # TODO; investigate this leak - # leak_tracker.check_for_leaks() - - @pytest.mark.it("Can send a JSON-formatted string that isn't wrapped in a Message object") - def test_sync_sends_json_string(self, client, service_helper, leak_tracker): - leak_tracker.set_initial_object_list() - - message = json.dumps(dev_utils.get_random_dict()) - - client.send_message(message) - - event = service_helper.wait_for_eventhub_arrival(None) - assert json.dumps(event.message_body) == message - - leak_tracker.check_for_leaks() - - @pytest.mark.it("Can send a random string that isn't wrapped in a Message object") - def test_sync_sends_random_string(self, client, service_helper, leak_tracker): - leak_tracker.set_initial_object_list() - - message = dev_utils.get_random_string(16) - - client.send_message(message) - - event = service_helper.wait_for_eventhub_arrival(None) - assert event.message_body == message - - leak_tracker.check_for_leaks() - - -@pytest.mark.dropped_connection -@pytest.mark.describe("Client send_message method with dropped connections") -@pytest.mark.keep_alive(5) -class TestSendMessageDroppedConnection(object): - @pytest.mark.it("Sends if connection drops before sending") - @pytest.mark.uses_iptables - def test_sync_sends_if_drop_before_sending( - self, client, random_message, dropper, service_helper, executor, leak_tracker - ): - leak_tracker.set_initial_object_list() - - assert client.connected - - dropper.drop_outgoing() - send_task = executor.submit(client.send_message, random_message) - - while client.connected: - time.sleep(1) - - assert not send_task.done() - - dropper.restore_all() - while not client.connected: - time.sleep(1) - - send_task.result() - - event = service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert json.dumps(event.message_body) 
== random_message.data - - random_message = None # so this doesn't get tagged as a leak - leak_tracker.check_for_leaks() - - @pytest.mark.it("Sends if connection rejects send") - @pytest.mark.uses_iptables - def test_sync_sends_if_reject_before_sending( - self, client, random_message, dropper, service_helper, executor, leak_tracker - ): - leak_tracker.set_initial_object_list() - - assert client.connected - - dropper.reject_outgoing() - send_task = executor.submit(client.send_message, random_message) - - while client.connected: - time.sleep(1) - - assert not send_task.done() - - dropper.restore_all() - while not client.connected: - time.sleep(1) - - send_task.result() - - event = service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert json.dumps(event.message_body) == random_message.data - - random_message = None # so this doesn't get tagged as a leak - leak_tracker.check_for_leaks() - - -@pytest.mark.describe("Client send_message with reconnect disabled") -@pytest.mark.keep_alive(5) -@pytest.mark.connection_retry(False) -class TestSendMessageRetryDisabled(object): - @pytest.fixture(scope="function", autouse=True) - def reconnect_after_test(self, dropper, client): - yield - dropper.restore_all() - client.connect() - assert client.connected - - @pytest.mark.it("Can send a simple message") - def test_sync_send_message_simple_with_retry_disabled( - self, client, random_message, service_helper, leak_tracker - ): - leak_tracker.set_initial_object_list() - - client.send_message(random_message) - - event = service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert json.dumps(event.message_body) == random_message.data - - leak_tracker.check_for_leaks() - - @pytest.mark.it("Automatically connects if transport manually disconnected before sending") - def test_sync_connect_if_necessary_with_retry_disabled( - self, client, random_message, service_helper, leak_tracker - ): - leak_tracker.set_initial_object_list() - - client.disconnect() - 
assert not client.connected - - client.send_message(random_message) - assert client.connected - - event = service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert json.dumps(event.message_body) == random_message.data - - leak_tracker.check_for_leaks() - - @pytest.mark.it("Automatically connects if transport automatically disconnected before sending") - @pytest.mark.uses_iptables - def test_sync_connects_after_automatic_disconnect_with_retry_disabled( - self, client, random_message, dropper, service_helper, leak_tracker - ): - leak_tracker.set_initial_object_list() - - assert client.connected - - dropper.drop_outgoing() - while client.connected: - time.sleep(1) - - assert not client.connected - dropper.restore_all() - client.send_message(random_message) - assert client.connected - - event = service_helper.wait_for_eventhub_arrival(random_message.message_id) - assert json.dumps(event.message_body) == random_message.data - - leak_tracker.check_for_leaks() - - @pytest.mark.it("Fails if connection disconnects before sending") - @pytest.mark.uses_iptables - def test_sync_fails_if_disconnect_before_sending_with_retry_disabled( - self, client, random_message, dropper, executor, leak_tracker - ): - leak_tracker.set_initial_object_list() - - assert client.connected - - dropper.drop_outgoing() - send_task = executor.submit(client.send_message, random_message) - - while client.connected: - time.sleep(1) - - with pytest.raises(OperationCancelled): - send_task.result() - - random_message = None # So this doesn't get tagged as a leak - # TODO: investigate this leak - # leak_tracker.check_for_leaks() - - @pytest.mark.it("Fails if connection drops before sending") - @pytest.mark.uses_iptables - def test_sync_fails_if_drop_before_sending_with_retry_disabled( - self, client, random_message, dropper, leak_tracker - ): - leak_tracker.set_initial_object_list() - - assert client.connected - - dropper.drop_outgoing() - with pytest.raises(OperationCancelled): - 
client.send_message(random_message) - - assert not client.connected - - random_message = None # So this doesn't get tagged as a leak - # TODO: investigate this leak - # leak_tracker.check_for_leaks() diff --git a/tests/e2e/iothub_e2e/sync/test_sync_twin.py b/tests/e2e/iothub_e2e/sync/test_sync_twin.py deleted file mode 100644 index 53a0647b9..000000000 --- a/tests/e2e/iothub_e2e/sync/test_sync_twin.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -import pytest -import logging -import time -import const -import queue -from dev_utils import get_random_dict -from azure.iot.device.exceptions import ClientError - -logger = logging.getLogger(__name__) -logger.setLevel(level=logging.INFO) - - -# TODO: tests with drop_incoming and reject_incoming - -reset_reported_props = {const.TEST_CONTENT: None} - - -@pytest.mark.describe("Client Reported Properties") -class TestReportedProperties(object): - @pytest.mark.it("Can set a simple reported property") - @pytest.mark.quicktest_suite - def test_sync_sends_simple_reported_patch( - self, client, random_reported_props, service_helper, leak_tracker - ): - leak_tracker.set_initial_object_list() - - # patch properties - client.patch_twin_reported_properties(random_reported_props) - - # wait for patch to arrive at service and verify - received_patch = service_helper.get_next_reported_patch_arrival() - assert ( - received_patch[const.REPORTED][const.TEST_CONTENT] - == random_reported_props[const.TEST_CONTENT] - ) - - # get twin from the service and verify content - twin = client.get_twin() - assert twin[const.REPORTED][const.TEST_CONTENT] == random_reported_props[const.TEST_CONTENT] - - @pytest.mark.it("Raises correct exception for un-serializable patch") - def test_sync_bad_reported_patch_raises(self, client, leak_tracker): - leak_tracker.set_initial_object_list() - - # There's no way to 
serialize a function. - def thing_that_cant_serialize(): - pass - - with pytest.raises(ClientError) as e_info: - client.patch_twin_reported_properties(thing_that_cant_serialize) - assert isinstance(e_info.value.__cause__, TypeError) - - @pytest.mark.it("Can clear a reported property") - @pytest.mark.quicktest_suite - def test_sync_clear_property(self, client, random_reported_props, service_helper, leak_tracker): - leak_tracker.set_initial_object_list() - - # patch properties and verify that the service received the patch - client.patch_twin_reported_properties(random_reported_props) - received_patch = service_helper.get_next_reported_patch_arrival() - assert ( - received_patch[const.REPORTED][const.TEST_CONTENT] - == random_reported_props[const.TEST_CONTENT] - ) - - # send a patch clearing properties and verify that the service received that patch - client.patch_twin_reported_properties(reset_reported_props) - received_patch = service_helper.get_next_reported_patch_arrival() - assert ( - received_patch[const.REPORTED][const.TEST_CONTENT] - == reset_reported_props[const.TEST_CONTENT] - ) - - # get the twin and verify that the properties are no longer part of the twin - twin = client.get_twin() - assert const.TEST_CONTENT not in twin[const.REPORTED] - - @pytest.mark.it("Connects the transport if necessary") - @pytest.mark.quicktest_suite - def test_sync_patch_reported_connect_if_necessary( - self, client, random_reported_props, service_helper, leak_tracker - ): - leak_tracker.set_initial_object_list() - - client.disconnect() - - assert not client.connected - client.patch_twin_reported_properties(random_reported_props) - assert client.connected - - received_patch = service_helper.get_next_reported_patch_arrival() - assert ( - received_patch[const.REPORTED][const.TEST_CONTENT] - == random_reported_props[const.TEST_CONTENT] - ) - - twin = client.get_twin() - assert twin[const.REPORTED][const.TEST_CONTENT] == random_reported_props[const.TEST_CONTENT] - - 
leak_tracker.check_for_leaks() - - -@pytest.mark.dropped_connection -@pytest.mark.describe("Client Reported Properties with dropped connection") -@pytest.mark.keep_alive(5) -class TestReportedPropertiesDroppedConnection(object): - - # TODO: split drop tests between first and second patches - - @pytest.mark.it("Updates reported properties if connection drops before sending") - def test_sync_updates_reported_if_drop_before_sending( - self, client, random_reported_props, dropper, service_helper, executor, leak_tracker - ): - leak_tracker.set_initial_object_list() - - assert client.connected - dropper.drop_outgoing() - - send_task = executor.submit(client.patch_twin_reported_properties, random_reported_props) - while client.connected: - time.sleep(1) - - assert not send_task.done() - - dropper.restore_all() - while not client.connected: - time.sleep(1) - - send_task.result() - - received_patch = service_helper.get_next_reported_patch_arrival() - assert ( - received_patch[const.REPORTED][const.TEST_CONTENT] - == random_reported_props[const.TEST_CONTENT] - ) - - # TODO: investigate leak - # leak_tracker.check_for_leaks() - - @pytest.mark.it("Updates reported properties if connection rejects send") - def test_sync_updates_reported_if_reject_before_sending( - self, client, random_reported_props, dropper, service_helper, executor, leak_tracker - ): - leak_tracker.set_initial_object_list() - - assert client.connected - dropper.reject_outgoing() - - send_task = executor.submit(client.patch_twin_reported_properties, random_reported_props) - while client.connected: - time.sleep(1) - - assert not send_task.done() - - dropper.restore_all() - while not client.connected: - time.sleep(1) - - send_task.result() - - received_patch = service_helper.get_next_reported_patch_arrival() - assert ( - received_patch[const.REPORTED][const.TEST_CONTENT] - == random_reported_props[const.TEST_CONTENT] - ) - - # TODO: investigate leak - # leak_tracker.check_for_leaks() - - 
-@pytest.mark.describe("Client Desired Properties") -class TestDesiredProperties(object): - @pytest.mark.it("Receives a patch for a simple desired property") - @pytest.mark.quicktest_suite - def test_sync_receives_simple_desired_patch(self, client, service_helper, leak_tracker): - received_patches = queue.Queue() - leak_tracker.set_initial_object_list() - - def handle_on_patch_received(patch): - nonlocal received_patches - print("received {}".format(patch)) - received_patches.put(patch) - - client.on_twin_desired_properties_patch_received = handle_on_patch_received - - # erase all old desired properties. Otherwise our random dict will only - # be part of the twin we get when we call `get_twin` below (because of - # properties from previous tests). - service_helper.set_desired_properties( - {const.TEST_CONTENT: None}, - ) - - random_dict = get_random_dict() - service_helper.set_desired_properties( - {const.TEST_CONTENT: random_dict}, - ) - - while True: - received_patch = received_patches.get(timeout=60) - - if received_patch[const.TEST_CONTENT] == random_dict: - twin = client.get_twin() - assert twin[const.DESIRED][const.TEST_CONTENT] == random_dict - break - - leak_tracker.check_for_leaks() - - -# TODO: etag tests, version tests diff --git a/tests/e2e/iothub_e2e/utils.py b/tests/e2e/iothub_e2e/utils.py index 27625a5d1..23d301687 100644 --- a/tests/e2e/iothub_e2e/utils.py +++ b/tests/e2e/iothub_e2e/utils.py @@ -5,7 +5,7 @@ from dev_utils import test_env import logging import sys -from azure.iot.device.iothub import Message +from azure.iot.device import Message, IoTHubSession logger = logging.getLogger(__name__) logger.setLevel(level=logging.INFO) @@ -27,52 +27,37 @@ def get_fault_injection_message(fault_injection_type): return fault_message -def create_client_object(device_identity, client_kwargs, DeviceClass, ModuleClass): - - if test_config.config.identity in [ - test_config.IDENTITY_DEVICE, - test_config.IDENTITY_EDGE_LEAF_DEVICE, - ]: - ClientClass = 
DeviceClass - elif test_config.config.identity in [ - test_config.IDENTITY_MODULE, - test_config.IDENTITY_EDGE_MODULE, - ]: - ClientClass = ModuleClass - else: - raise Exception("config.identity invalid") +def create_session(device_identity, client_kwargs): if test_config.config.auth == test_config.AUTH_CONNECTION_STRING: logger.info( - "Creating {} using create_from_connection_string with kwargs={}".format( - ClientClass, client_kwargs + "Creating session using create_from_connection_string with kwargs={}".format( + client_kwargs ) ) - client = ClientClass.create_from_connection_string( + session = IoTHubSession.from_connection_string( device_identity.connection_string, **client_kwargs ) elif test_config.config.auth == test_config.AUTH_SYMMETRIC_KEY: logger.info( - "Creating {} using create_from_symmetric_key with kwargs={}".format( - ClientClass, client_kwargs - ) + "Creating session using create_from_symmetric_key with kwargs={}".format(client_kwargs) ) - client = ClientClass.create_from_symmetric_key( - device_identity.primary_key, - test_env.IOTHUB_HOSTNAME, - device_identity.device_id, + session = IoTHubSession( + shared_access_key=device_identity.primary_key, + hostname=test_env.IOTHUB_HOSTNAME, + device_id=device_identity.device_id, **client_kwargs ) elif test_config.config.auth == test_config.AUTH_SAS_TOKEN: logger.info( - "Creating {} using create_from_sastoken with kwargs={}".format( - ClientClass, client_kwargs - ) + "Creating session using create_from_sastoken with kwargs={}".format(client_kwargs) ) - client = ClientClass.create_from_sastoken(device_identity.sas_token, **client_kwargs) + # client = ClientClass.create_from_sastoken(device_identity.sas_token, **client_kwargs) + + raise Exception("{} Auth not yet implemented".format(test_config.config.auth)) elif test_config.config.auth in test_config.AUTH_CHOICES: # need to implement @@ -80,7 +65,7 @@ def create_client_object(device_identity, client_kwargs, DeviceClass, ModuleClas else: raise 
Exception("config.auth invalid") - return client + return session def is_windows(): diff --git a/tests/e2e/provisioning_e2e/connection_string.py b/tests/e2e/provisioning_e2e/connection_string.py index 62366f1b9..916c073b9 100644 --- a/tests/e2e/provisioning_e2e/connection_string.py +++ b/tests/e2e/provisioning_e2e/connection_string.py @@ -1,22 +1,91 @@ -# # ------------------------------------------------------------------------- -# # Copyright (c) Microsoft Corporation. All rights reserved. -# # Licensed under the MIT License. See License.txt in the project root for -# # license information. -# # -------------------------------------------------------------------------- -# from azure.iot.device.common.connection_string import ConnectionString -# from azure.iot.device.common.sastoken import SasToken -# -# -# def connection_string_to_sas_token(conn_str): -# """ -# parse an IoTHub service connection string and return the host and a shared access -# signature that can be used to connect to the given hub -# """ -# conn_str_obj = ConnectionString(conn_str) -# sas_token = SasToken( -# uri=conn_str_obj.get("HostName"), -# key=conn_str_obj.get("SharedAccessKey"), -# key_name=conn_str_obj.get("SharedAccessKeyName"), -# ) -# -# return {"host": conn_str_obj.get("HostName"), "sas": str(sas_token)} +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +"""This module contains tools for working with Connection Strings""" + +__all__ = ["ConnectionString"] + +CS_DELIMITER = ";" +CS_VAL_SEPARATOR = "=" + +HOST_NAME = "HostName" +SHARED_ACCESS_KEY_NAME = "SharedAccessKeyName" +SHARED_ACCESS_KEY = "SharedAccessKey" +SHARED_ACCESS_SIGNATURE = "SharedAccessSignature" +DEVICE_ID = "DeviceId" +MODULE_ID = "ModuleId" +GATEWAY_HOST_NAME = "GatewayHostName" + +_valid_keys = [ + HOST_NAME, + SHARED_ACCESS_KEY_NAME, + SHARED_ACCESS_KEY, + SHARED_ACCESS_SIGNATURE, + DEVICE_ID, + MODULE_ID, + GATEWAY_HOST_NAME, +] + + +def _parse_connection_string(connection_string): + """Return a dictionary of values contained in a given connection string""" + cs_args = connection_string.split(CS_DELIMITER) + d = dict(arg.split(CS_VAL_SEPARATOR, 1) for arg in cs_args) + if len(cs_args) != len(d): + # various errors related to incorrect parsing - duplicate args, bad syntax, etc. + raise ValueError("Invalid Connection String - Unable to parse") + if not all(key in _valid_keys for key in d.keys()): + raise ValueError("Invalid Connection String - Invalid Key") + _validate_keys(d) + return d + + +def _validate_keys(d): + """Raise ValueError if incorrect combination of keys in dict d""" + host_name = d.get(HOST_NAME) + shared_access_key_name = d.get(SHARED_ACCESS_KEY_NAME) + shared_access_key = d.get(SHARED_ACCESS_KEY) + device_id = d.get(DEVICE_ID) + + # This logic could be expanded to return the category of ConnectionString + if host_name and device_id and shared_access_key: + pass + elif host_name and shared_access_key and shared_access_key_name: + pass + else: + raise ValueError("Invalid Connection String - Incomplete") + + +class ConnectionString(object): + """Key/value mappings for connection details. 
+ Uses the same syntax as dictionary + """ + + def __init__(self, connection_string): + """Initializer for ConnectionString + + :param str connection_string: String with connection details provided by Azure + :raises: ValueError if provided connection_string is invalid + """ + self._dict = _parse_connection_string(connection_string) + self._strrep = connection_string + + def __getitem__(self, key): + return self._dict[key] + + def __repr__(self): + return self._strrep + + def get(self, key, default=None): + """Return the value for key if key is in the dictionary, else default + + :param str key: The key to retrieve a value for + :param str default: The default value returned if a key is not found + :returns: The value for the given key + """ + try: + return self._dict[key] + except KeyError: + return default diff --git a/tests/e2e/provisioning_e2e/pytest.ini b/tests/e2e/provisioning_e2e/pytest.ini index 2f799d7d7..2973fab51 100644 --- a/tests/e2e/provisioning_e2e/pytest.ini +++ b/tests/e2e/provisioning_e2e/pytest.ini @@ -1,2 +1,3 @@ [pytest] -addopts = --timeout 30 \ No newline at end of file +addopts = --timeout 90 +asyncio_mode=auto diff --git a/tests/e2e/provisioning_e2e/sastoken.py b/tests/e2e/provisioning_e2e/sastoken.py new file mode 100644 index 000000000..1b83dd176 --- /dev/null +++ b/tests/e2e/provisioning_e2e/sastoken.py @@ -0,0 +1,81 @@ +"""This module contains tools for working with Shared Access Signature (SAS) Tokens""" + +import base64 +import hmac +import hashlib +import time + +import urllib.parse + + +class SasTokenError(Exception): + """Error in SasToken""" + + def __init__(self, message, cause=None): + """Initializer for SasTokenError + + :param str message: Error message + :param cause: Exception that caused this error (optional) + """ + super(SasTokenError, self).__init__(message) + self.cause = cause + + +class SasToken(object): + """Shared Access Signature Token used to authenticate a request + + Parameters: + uri (str): URI of the resouce 
to be accessed + key_name (str): Shared Access Key Name + key (str): Shared Access Key (base64 encoded) + ttl (int)[default 3600]: Time to live for the token, in seconds + + Data Attributes: + expiry_time (int): Time that token will expire (in UTC, since epoch) + ttl (int): Time to live for the token, in seconds + + Raises: + SasTokenError if trying to build a SasToken from invalid values + """ + + _encoding_type = "utf-8" + _service_token_format = "SharedAccessSignature sr={}&sig={}&se={}&skn={}" + _device_token_format = "SharedAccessSignature sr={}&sig={}&se={}" + + def __init__(self, uri, key, key_name=None, ttl=3600): + self._uri = urllib.parse.quote(uri, safe="") + self._key = key + self._key_name = key_name + self.ttl = ttl + self.refresh() + + def __str__(self): + return self._token + + def refresh(self): + """ + Refresh the SasToken lifespan, giving it a new expiry time + """ + self.expiry_time = int(time.time() + self.ttl) + self._token = self._build_token() + + def _build_token(self): + """Buid SasToken representation + + Returns: + String representation of the token + """ + try: + message = (self._uri + "\n" + str(self.expiry_time)).encode(self._encoding_type) + signing_key = base64.b64decode(self._key.encode(self._encoding_type)) + signed_hmac = hmac.HMAC(signing_key, message, hashlib.sha256) + signature = urllib.parse.quote(base64.b64encode(signed_hmac.digest())) + except (TypeError, base64.binascii.Error) as e: + raise SasTokenError("Unable to build SasToken from given values", e) + if self._key_name: + token = self._service_token_format.format( + self._uri, signature, str(self.expiry_time), self._key_name + ) + else: + token = self._device_token_format.format(self._uri, signature, str(self.expiry_time)) + return token diff --git a/tests/e2e/provisioning_e2e/service_helper.py b/tests/e2e/provisioning_e2e/service_helper.py index 5a5dee738..568f16749 100644 --- a/tests/e2e/provisioning_e2e/service_helper.py +++ 
b/tests/e2e/provisioning_e2e/service_helper.py @@ -3,15 +3,17 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +from typing import Optional -from provisioning_e2e.iothubservice20180630.iot_hub_gateway_service_ap_is20180630 import ( +from .iothubservice20180630.iot_hub_gateway_service_ap_is20180630 import ( IotHubGatewayServiceAPIs20180630, ) from msrest.exceptions import HttpOperationError -from azure.iot.device.common.auth.connection_string import ConnectionString -from azure.iot.device.common.auth.sastoken import RenewableSasToken -from azure.iot.device.common.auth.signing_mechanism import SymmetricKeySigningMechanism + +from .connection_string import ConnectionString +from .sastoken import SasToken + import uuid import time import random @@ -27,10 +29,9 @@ def connection_string_to_sas_token(conn_str): signature that can be used to connect to the given hub """ conn_str_obj = ConnectionString(conn_str) - signing_mechanism = SymmetricKeySigningMechanism(conn_str_obj.get("SharedAccessKey")) - sas_token = RenewableSasToken( + sas_token = SasToken( uri=conn_str_obj.get("HostName"), - signing_mechanism=signing_mechanism, + key=conn_str_obj.get("SharedAccessKey"), key_name=conn_str_obj.get("SharedAccessKeyName"), ) @@ -46,6 +47,16 @@ def connection_string_to_hostname(conn_str): return conn_str_obj.get("HostName") +def _format_sas_uri(hostname: str, device_id: str, module_id: Optional[str]) -> str: + """Format the SAS URI for using IoT Hub""" + if module_id: + return "{hostname}/devices/{device_id}/modules/{module_id}".format( + hostname=hostname, device_id=device_id, module_id=module_id + ) + else: + return "{hostname}/devices/{device_id}".format(hostname=hostname, device_id=device_id) + + def run_with_retry(fun, args, kwargs): failures_left = max_failure_count retry = True @@ -70,7 +81,7 @@ def run_with_retry(fun, args, kwargs): raise e -class 
Helper: +class ServiceRegistryHelper: def __init__(self, service_connection_string): self.cn = connection_string_to_sas_token(service_connection_string) self.service = IotHubGatewayServiceAPIs20180630("https://" + self.cn["host"]).service diff --git a/tests/e2e/provisioning_e2e/tests/path_adjust.py b/tests/e2e/provisioning_e2e/tests/path_adjust.py index f06decb63..dae92ea4b 100644 --- a/tests/e2e/provisioning_e2e/tests/path_adjust.py +++ b/tests/e2e/provisioning_e2e/tests/path_adjust.py @@ -13,6 +13,7 @@ # no longer true, we can get rid of this file. root_path = dir(dir(sys.path[0])) script_path = os.path.join(root_path, "scripts") +print("The path after scripts is") print(script_path) if script_path not in sys.path: sys.path.append(script_path) diff --git a/tests/e2e/provisioning_e2e/tests/test_async_certificate_enrollments.py b/tests/e2e/provisioning_e2e/tests/test_async_certificate_enrollments.py index 1b8ea4ab3..61c93a783 100644 --- a/tests/e2e/provisioning_e2e/tests/test_async_certificate_enrollments.py +++ b/tests/e2e/provisioning_e2e/tests/test_async_certificate_enrollments.py @@ -5,9 +5,9 @@ # -------------------------------------------------------------------------- -from provisioning_e2e.service_helper import Helper, connection_string_to_hostname -from azure.iot.device.aio import ProvisioningDeviceClient -from azure.iot.device.common import X509 +from ..service_helper import ServiceRegistryHelper, connection_string_to_hostname +from azure.iot.device import ProvisioningSession + from provisioningserviceclient import ( ProvisioningServiceClient, IndividualEnrollment, @@ -18,18 +18,17 @@ import logging import os import uuid - +import ssl from . import path_adjust # noqa: F401 # Refers to an item in "scripts" in the root. 
This is made to work via the above path_adjust -from create_x509_chain_crypto import ( +from scripts.create_x509_chain_crypto import ( before_cert_creation_from_pipeline, call_intermediate_cert_and_device_cert_creation_from_pipeline, delete_directories_certs_created_from_pipeline, ) -pytestmark = pytest.mark.asyncio logging.basicConfig(level=logging.DEBUG) @@ -41,7 +40,7 @@ service_client = ProvisioningServiceClient.create_from_connection_string( os.getenv("PROVISIONING_SERVICE_CONNECTION_STRING") ) -device_registry_helper = Helper(os.getenv("IOTHUB_CONNECTION_STRING")) +device_registry_helper = ServiceRegistryHelper(os.getenv("IOTHUB_CONNECTION_STRING")) linked_iot_hub = connection_string_to_hostname(os.getenv("IOTHUB_CONNECTION_STRING")) PROVISIONING_HOST = os.getenv("PROVISIONING_DEVICE_ENDPOINT") @@ -96,6 +95,7 @@ async def test_device_register_with_device_id_for_a_x509_individual_enrollment(p registration_id, device_cert_file, device_key_file, protocol ) + assert registration_result is not None assert device_id != registration_id assert_device_provisioned(device_id=device_id, registration_result=registration_result) device_registry_helper.try_delete_device(device_id) @@ -122,6 +122,7 @@ async def test_device_register_with_no_device_id_for_a_x509_individual_enrollmen registration_id, device_cert_file, device_key_file, protocol ) + assert registration_result is not None assert_device_provisioned( device_id=registration_id, registration_result=registration_result ) @@ -180,6 +181,7 @@ async def test_group_of_devices_register_with_no_device_id_for_a_x509_intermedia protocol=protocol, ) + assert registration_result is not None assert_device_provisioned(device_id=device_id, registration_result=registration_result) device_registry_helper.try_delete_device(device_id) @@ -243,7 +245,7 @@ async def test_group_of_devices_register_with_no_device_id_for_a_x509_ca_authent device_key_file=device_key_input_file, protocol=protocol, ) - + assert registration_result is not None 
assert_device_provisioned(device_id=device_id, registration_result=registration_result) device_registry_helper.try_delete_device(device_id) @@ -259,9 +261,9 @@ def assert_device_provisioned(device_id, registration_result): :param device_id: The device id :param registration_result: The registration result """ - assert registration_result.status == "assigned" - assert registration_result.registration_state.device_id == device_id - assert registration_result.registration_state.assigned_hub == linked_iot_hub + assert registration_result["status"] == "assigned" + assert registration_result["registrationState"]["deviceId"] == device_id + assert registration_result["registrationState"]["assignedHub"] == linked_iot_hub device = device_registry_helper.get_device(device_id) assert device is not None @@ -290,14 +292,31 @@ def create_individual_enrollment_with_x509_client_certs(device_index, device_id= async def result_from_register(registration_id, device_cert_file, device_key_file, protocol): - x509 = X509(cert_file=device_cert_file, key_file=device_key_file, pass_phrase=device_password) + # We have this mapping because the pytest logs look better with "mqtt" and "mqttws" + # instead of just "True" and "False". 
protocol_boolean_mapping = {"mqtt": False, "mqttws": True} - provisioning_device_client = ProvisioningDeviceClient.create_from_x509_certificate( - provisioning_host=PROVISIONING_HOST, + ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.check_hostname = True + ssl_context.load_default_certs() + ssl_context.load_cert_chain( + certfile=device_cert_file, + keyfile=device_key_file, + password=device_password, + ) + + async with ProvisioningSession( + provisioning_endpoint=PROVISIONING_HOST, registration_id=registration_id, id_scope=ID_SCOPE, - x509=x509, + ssl_context=ssl_context, websockets=protocol_boolean_mapping[protocol], - ) - - return await provisioning_device_client.register() + ) as session: + print("Connected") + properties = {"Type": "Apple", "Sweet": True, "count": 5} + result = await session.register(payload=properties) + print("Finished provisioning") + print(result) + result = await session.register() + return result if result is not None else None diff --git a/tests/e2e/provisioning_e2e/tests/test_async_symmetric_enrollments.py b/tests/e2e/provisioning_e2e/tests/test_async_symmetric_enrollments.py index 3657db15b..ff0b05974 100644 --- a/tests/e2e/provisioning_e2e/tests/test_async_symmetric_enrollments.py +++ b/tests/e2e/provisioning_e2e/tests/test_async_symmetric_enrollments.py @@ -4,8 +4,7 @@ # license information. 
# -------------------------------------------------------------------------- -from provisioning_e2e.service_helper import Helper, connection_string_to_hostname -from azure.iot.device.aio import ProvisioningDeviceClient +from ..service_helper import ServiceRegistryHelper, connection_string_to_hostname from provisioningserviceclient import ProvisioningServiceClient, IndividualEnrollment from provisioningserviceclient.protocol.models import AttestationMechanism, ReprovisionPolicy import pytest @@ -13,7 +12,8 @@ import os import uuid -pytestmark = pytest.mark.asyncio +from azure.iot.device import ProvisioningSession + logging.basicConfig(level=logging.DEBUG) @@ -24,7 +24,7 @@ os.getenv("PROVISIONING_SERVICE_CONNECTION_STRING") ) service_client = ProvisioningServiceClient.create_from_connection_string(conn_str) -device_registry_helper = Helper(os.getenv("IOTHUB_CONNECTION_STRING")) +device_registry_helper = ServiceRegistryHelper(os.getenv("IOTHUB_CONNECTION_STRING")) linked_iot_hub = connection_string_to_hostname(os.getenv("IOTHUB_CONNECTION_STRING")) @@ -46,6 +46,7 @@ async def test_device_register_with_no_device_id_for_a_symmetric_key_individual_ registration_result = await result_from_register(registration_id, symmetric_key, protocol) + assert registration_result is not None assert_device_provisioned( device_id=registration_id, registration_result=registration_result ) @@ -71,6 +72,7 @@ async def test_device_register_with_device_id_for_a_symmetric_key_individual_enr registration_result = await result_from_register(registration_id, symmetric_key, protocol) + assert registration_result is not None assert device_id != registration_id assert_device_provisioned(device_id=device_id, registration_result=registration_result) device_registry_helper.try_delete_device(device_id) @@ -104,9 +106,10 @@ def assert_device_provisioned(device_id, registration_result): :param device_id: The device id :param registration_result: The registration result """ - assert 
registration_result.status == "assigned" - assert registration_result.registration_state.device_id == device_id - assert registration_result.registration_state.assigned_hub == linked_iot_hub + print(registration_result) + assert registration_result["status"] == "assigned" + assert registration_result["registrationState"]["deviceId"] == device_id + assert registration_result["registrationState"]["assignedHub"] == linked_iot_hub device = device_registry_helper.get_device(device_id) assert device is not None @@ -114,17 +117,16 @@ def assert_device_provisioned(device_id, registration_result): assert device.device_id == device_id -# TODO Eventually should return result after the APi changes async def result_from_register(registration_id, symmetric_key, protocol): # We have this mapping because the pytest logs look better with "mqtt" and "mqttws" # instead of just "True" and "False". protocol_boolean_mapping = {"mqtt": False, "mqttws": True} - provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=PROVISIONING_HOST, + async with ProvisioningSession( + provisioning_endpoint=PROVISIONING_HOST, registration_id=registration_id, id_scope=ID_SCOPE, - symmetric_key=symmetric_key, + shared_access_key=symmetric_key, websockets=protocol_boolean_mapping[protocol], - ) - - return await provisioning_device_client.register() + ) as session: + result = await session.register() + return result if result is not None else None diff --git a/tests/e2e/provisioning_e2e/tests/test_sync_certificate_enrollments.py b/tests/e2e/provisioning_e2e/tests/test_sync_certificate_enrollments.py deleted file mode 100644 index e8d039e19..000000000 --- a/tests/e2e/provisioning_e2e/tests/test_sync_certificate_enrollments.py +++ /dev/null @@ -1,302 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -from provisioning_e2e.service_helper import Helper, connection_string_to_hostname -from azure.iot.device import ProvisioningDeviceClient -from azure.iot.device.common import X509 -from provisioningserviceclient import ( - ProvisioningServiceClient, - IndividualEnrollment, - EnrollmentGroup, -) -from provisioningserviceclient.protocol.models import AttestationMechanism, ReprovisionPolicy -import pytest -import logging -import os -import uuid - -from . import path_adjust # noqa: F401 - -# Refers to an item in "scripts" in the root. This is made to work via the above path_adjust -from create_x509_chain_crypto import ( - before_cert_creation_from_pipeline, - call_intermediate_cert_and_device_cert_creation_from_pipeline, - delete_directories_certs_created_from_pipeline, -) - - -logging.basicConfig(level=logging.DEBUG) - - -intermediate_common_name = "e2edpswingardium" -intermediate_password = "leviosa" -device_common_name = "e2edpsexpecto" + str(uuid.uuid4()) -device_password = "patronum" - -service_client = ProvisioningServiceClient.create_from_connection_string( - os.getenv("PROVISIONING_SERVICE_CONNECTION_STRING") -) -device_registry_helper = Helper(os.getenv("IOTHUB_CONNECTION_STRING")) -linked_iot_hub = connection_string_to_hostname(os.getenv("IOTHUB_CONNECTION_STRING")) -PROVISIONING_HOST = os.getenv("PROVISIONING_DEVICE_ENDPOINT") -ID_SCOPE = os.getenv("PROVISIONING_DEVICE_IDSCOPE") - - -certificate_count = 8 -type_to_device_indices = { - "individual_with_device_id": [1], - "individual_no_device_id": [2], - "group_intermediate": [3, 4, 5], - "group_ca": [6, 7, 8], -} - - -@pytest.fixture(scope="module", autouse=True) -def before_all_tests(request): - logging.info("set up certificates before cert related tests") - before_cert_creation_from_pipeline() - call_intermediate_cert_and_device_cert_creation_from_pipeline( - 
intermediate_common_name=intermediate_common_name, - device_common_name=device_common_name, - ca_password=os.getenv("PROVISIONING_ROOT_PASSWORD"), - intermediate_password=intermediate_password, - device_password=device_password, - device_count=8, - ) - - def after_module(): - logging.info("tear down certificates after cert related tests") - delete_directories_certs_created_from_pipeline() - - request.addfinalizer(after_module) - - -@pytest.mark.it( - "A device gets provisioned to the linked IoTHub with the user supplied device_id different from the registration_id of the individual enrollment that has been created with a selfsigned X509 authentication" -) -@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) -def test_device_register_with_device_id_for_a_x509_individual_enrollment(protocol): - device_id = "e2edpsflyingfeather" - device_index = type_to_device_indices.get("individual_with_device_id")[0] - - try: - individual_enrollment_record = create_individual_enrollment_with_x509_client_certs( - device_index=device_index, device_id=device_id - ) - registration_id = individual_enrollment_record.registration_id - - device_cert_file = "demoCA/newcerts/device_cert" + str(device_index) + ".pem" - device_key_file = "demoCA/private/device_key" + str(device_index) + ".pem" - registration_result = result_from_register( - registration_id, device_cert_file, device_key_file, protocol - ) - - assert device_id != registration_id - assert_device_provisioned(device_id=device_id, registration_result=registration_result) - device_registry_helper.try_delete_device(device_id) - finally: - service_client.delete_individual_enrollment_by_param(registration_id) - - -@pytest.mark.it( - "A device gets provisioned to the linked IoTHub with device_id equal to the registration_id of the individual enrollment that has been created with a selfsigned X509 authentication" -) -@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) -def 
test_device_register_with_no_device_id_for_a_x509_individual_enrollment(protocol): - device_index = type_to_device_indices.get("individual_no_device_id")[0] - - try: - individual_enrollment_record = create_individual_enrollment_with_x509_client_certs( - device_index=device_index - ) - registration_id = individual_enrollment_record.registration_id - - device_cert_file = "demoCA/newcerts/device_cert" + str(device_index) + ".pem" - device_key_file = "demoCA/private/device_key" + str(device_index) + ".pem" - registration_result = result_from_register( - registration_id, device_cert_file, device_key_file, protocol - ) - - assert_device_provisioned( - device_id=registration_id, registration_result=registration_result - ) - device_registry_helper.try_delete_device(registration_id) - finally: - service_client.delete_individual_enrollment_by_param(registration_id) - - -@pytest.mark.it( - "A group of devices get provisioned to the linked IoTHub with device_ids equal to the individual registration_ids inside a group enrollment that has been created with intermediate X509 authentication" -) -@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) -def test_group_of_devices_register_with_no_device_id_for_a_x509_intermediate_authentication_group_enrollment( - protocol, -): - group_id = "e2e-intermediate-hogwarts" + str(uuid.uuid4()) - common_device_id = device_common_name - devices_indices = type_to_device_indices.get("group_intermediate") - device_count_in_group = len(devices_indices) - reprovision_policy = ReprovisionPolicy(migrate_device_data=True) - - try: - intermediate_cert_filename = "demoCA/newcerts/intermediate_cert.pem" - with open(intermediate_cert_filename, "r") as intermediate_pem: - intermediate_cert_content = intermediate_pem.read() - - attestation_mechanism = AttestationMechanism.create_with_x509_signing_certs( - intermediate_cert_content - ) - enrollment_group_provisioning_model = EnrollmentGroup.create( - group_id, attestation=attestation_mechanism, 
reprovision_policy=reprovision_policy - ) - - service_client.create_or_update(enrollment_group_provisioning_model) - - count = 0 - common_device_key_input_file = "demoCA/private/device_key" - common_device_cert_input_file = "demoCA/newcerts/device_cert" - common_device_inter_cert_chain_file = "demoCA/newcerts/out_inter_device_chain_cert" - for index in devices_indices: - count = count + 1 - device_id = common_device_id + str(index) - device_key_input_file = common_device_key_input_file + str(index) + ".pem" - device_cert_input_file = common_device_cert_input_file + str(index) + ".pem" - device_inter_cert_chain_file = common_device_inter_cert_chain_file + str(index) + ".pem" - - filenames = [device_cert_input_file, intermediate_cert_filename] - with open(device_inter_cert_chain_file, "w") as outfile: - for fname in filenames: - with open(fname) as infile: - outfile.write(infile.read()) - - registration_result = result_from_register( - registration_id=device_id, - device_cert_file=device_inter_cert_chain_file, - device_key_file=device_key_input_file, - protocol=protocol, - ) - - assert_device_provisioned(device_id=device_id, registration_result=registration_result) - device_registry_helper.try_delete_device(device_id) - - # Make sure space is okay. The following line must be outside for loop. 
- assert count == device_count_in_group - - finally: - service_client.delete_enrollment_group_by_param(group_id) - - -@pytest.mark.skip( - reason="The enrollment is never properly created on the pipeline and it is always created without any CA reference and eventually the registration fails" -) -@pytest.mark.it( - "A group of devices get provisioned to the linked IoTHub with device_ids equal to the individual registration_ids inside a group enrollment that has been created with an already uploaded ca cert X509 authentication" -) -@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) -def test_group_of_devices_register_with_no_device_id_for_a_x509_ca_authentication_group_enrollment( - protocol, -): - group_id = "e2e-ca-beauxbatons" + str(uuid.uuid4()) - common_device_id = device_common_name - devices_indices = type_to_device_indices.get("group_ca") - device_count_in_group = len(devices_indices) - reprovision_policy = ReprovisionPolicy(migrate_device_data=True) - - try: - DPS_GROUP_CA_CERT = os.getenv("PROVISIONING_ROOT_CERT") - attestation_mechanism = AttestationMechanism.create_with_x509_ca_refs( - ref1=DPS_GROUP_CA_CERT - ) - enrollment_group_provisioning_model = EnrollmentGroup.create( - group_id, attestation=attestation_mechanism, reprovision_policy=reprovision_policy - ) - - service_client.create_or_update(enrollment_group_provisioning_model) - - count = 0 - intermediate_cert_filename = "demoCA/newcerts/intermediate_cert.pem" - common_device_key_input_file = "demoCA/private/device_key" - common_device_cert_input_file = "demoCA/newcerts/device_cert" - common_device_inter_cert_chain_file = "demoCA/newcerts/out_inter_device_chain_cert" - for index in devices_indices: - count = count + 1 - device_id = common_device_id + str(index) - device_key_input_file = common_device_key_input_file + str(index) + ".pem" - device_cert_input_file = common_device_cert_input_file + str(index) + ".pem" - device_inter_cert_chain_file = common_device_inter_cert_chain_file + 
str(index) + ".pem" - filenames = [device_cert_input_file, intermediate_cert_filename] - with open(device_inter_cert_chain_file, "w") as outfile: - for fname in filenames: - with open(fname) as infile: - logging.debug("Filename is {}".format(fname)) - content = infile.read() - logging.debug(content) - outfile.write(content) - - registration_result = result_from_register( - registration_id=device_id, - device_cert_file=device_inter_cert_chain_file, - device_key_file=device_key_input_file, - protocol=protocol, - ) - - assert_device_provisioned(device_id=device_id, registration_result=registration_result) - device_registry_helper.try_delete_device(device_id) - - # Make sure space is okay. The following line must be outside for loop. - assert count == device_count_in_group - finally: - service_client.delete_enrollment_group_by_param(group_id) - - -def assert_device_provisioned(device_id, registration_result): - """ - Assert that the device has been provisioned correctly to iothub from the registration result as well as from the device registry - :param device_id: The device id - :param registration_result: The registration result - """ - assert registration_result.status == "assigned" - assert registration_result.registration_state.device_id == device_id - assert registration_result.registration_state.assigned_hub == linked_iot_hub - - device = device_registry_helper.get_device(device_id) - assert device is not None - assert device.authentication.type == "selfSigned" - assert device.device_id == device_id - - -def create_individual_enrollment_with_x509_client_certs(device_index, device_id=None): - registration_id = device_common_name + str(device_index) - reprovision_policy = ReprovisionPolicy(migrate_device_data=True) - - device_cert_input_file = "demoCA/newcerts/device_cert" + str(device_index) + ".pem" - with open(device_cert_input_file, "r") as in_device_cert: - device_cert_content = in_device_cert.read() - - attestation_mechanism = 
AttestationMechanism.create_with_x509_client_certs(device_cert_content) - - individual_provisioning_model = IndividualEnrollment.create( - attestation=attestation_mechanism, - registration_id=registration_id, - reprovision_policy=reprovision_policy, - device_id=device_id, - ) - - return service_client.create_or_update(individual_provisioning_model) - - -def result_from_register(registration_id, device_cert_file, device_key_file, protocol): - x509 = X509(cert_file=device_cert_file, key_file=device_key_file, pass_phrase=device_password) - protocol_boolean_mapping = {"mqtt": False, "mqttws": True} - provisioning_device_client = ProvisioningDeviceClient.create_from_x509_certificate( - provisioning_host=PROVISIONING_HOST, - registration_id=registration_id, - id_scope=ID_SCOPE, - x509=x509, - websockets=protocol_boolean_mapping[protocol], - ) - - return provisioning_device_client.register() diff --git a/tests/e2e/provisioning_e2e/tests/test_sync_symmetric_enrollments.py b/tests/e2e/provisioning_e2e/tests/test_sync_symmetric_enrollments.py deleted file mode 100644 index 29fb0d168..000000000 --- a/tests/e2e/provisioning_e2e/tests/test_sync_symmetric_enrollments.py +++ /dev/null @@ -1,120 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -from provisioning_e2e.service_helper import Helper, connection_string_to_hostname -from azure.iot.device import ProvisioningDeviceClient -from provisioningserviceclient import ProvisioningServiceClient, IndividualEnrollment -from provisioningserviceclient.protocol.models import AttestationMechanism, ReprovisionPolicy -import pytest -import logging -import os -import uuid - -logging.basicConfig(level=logging.DEBUG) - -PROVISIONING_HOST = os.getenv("PROVISIONING_DEVICE_ENDPOINT") -ID_SCOPE = os.getenv("PROVISIONING_DEVICE_IDSCOPE") -service_client = ProvisioningServiceClient.create_from_connection_string( - os.getenv("PROVISIONING_SERVICE_CONNECTION_STRING") -) -device_registry_helper = Helper(os.getenv("IOTHUB_CONNECTION_STRING")) -linked_iot_hub = connection_string_to_hostname(os.getenv("IOTHUB_CONNECTION_STRING")) - - -@pytest.mark.it( - "A device gets provisioned to the linked IoTHub with the device_id equal to the registration_id of the individual enrollment that has been created with a symmetric key authentication" -) -@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) -def test_device_register_with_no_device_id_for_a_symmetric_key_individual_enrollment(protocol): - try: - individual_enrollment_record = create_individual_enrollment( - "e2e-dps-underthewhompingwillow" + str(uuid.uuid4()) - ) - - registration_id = individual_enrollment_record.registration_id - symmetric_key = individual_enrollment_record.attestation.symmetric_key.primary_key - - registration_result = result_from_register(registration_id, symmetric_key, protocol) - - assert_device_provisioned( - device_id=registration_id, registration_result=registration_result - ) - device_registry_helper.try_delete_device(registration_id) - finally: - service_client.delete_individual_enrollment_by_param(registration_id) - - -@pytest.mark.it( - "A device gets provisioned to the linked IoTHub with the user supplied device_id 
different from the registration_id of the individual enrollment that has been created with a symmetric key authentication" -) -@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) -def test_device_register_with_device_id_for_a_symmetric_key_individual_enrollment(protocol): - - device_id = "e2edpstommarvoloriddle" - try: - individual_enrollment_record = create_individual_enrollment( - registration_id="e2e-dps-prioriincantatem" + str(uuid.uuid4()), device_id=device_id - ) - - registration_id = individual_enrollment_record.registration_id - symmetric_key = individual_enrollment_record.attestation.symmetric_key.primary_key - - registration_result = result_from_register(registration_id, symmetric_key, protocol) - - assert device_id != registration_id - assert_device_provisioned(device_id=device_id, registration_result=registration_result) - device_registry_helper.try_delete_device(device_id) - finally: - service_client.delete_individual_enrollment_by_param(registration_id) - - -def create_individual_enrollment(registration_id, device_id=None): - """ - Create an individual enrollment record using the service client - :param registration_id: The registration id of the enrollment - :param device_id: Optional device id - :return: And individual enrollment record - """ - reprovision_policy = ReprovisionPolicy(migrate_device_data=True) - attestation_mechanism = AttestationMechanism(type="symmetricKey") - - individual_provisioning_model = IndividualEnrollment.create( - attestation=attestation_mechanism, - registration_id=registration_id, - device_id=device_id, - reprovision_policy=reprovision_policy, - ) - - return service_client.create_or_update(individual_provisioning_model) - - -def assert_device_provisioned(device_id, registration_result): - """ - Assert that the device has been provisioned correctly to iothub from the registration result as well as from the device registry - :param device_id: The device id - :param registration_result: The registration result - """ - 
assert registration_result.status == "assigned" - assert registration_result.registration_state.device_id == device_id - assert registration_result.registration_state.assigned_hub == linked_iot_hub - - device = device_registry_helper.get_device(device_id) - assert device is not None - assert device.authentication.type == "sas" - assert device.device_id == device_id - - -def result_from_register(registration_id, symmetric_key, protocol): - protocol_boolean_mapping = {"mqtt": False, "mqttws": True} - provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=PROVISIONING_HOST, - registration_id=registration_id, - id_scope=ID_SCOPE, - symmetric_key=symmetric_key, - websockets=protocol_boolean_mapping[protocol], - ) - - return provisioning_device_client.register() diff --git a/tests/unit/common/auth/test_sastoken.py b/tests/unit/common/auth/test_sastoken.py deleted file mode 100644 index 5d8688190..000000000 --- a/tests/unit/common/auth/test_sastoken.py +++ /dev/null @@ -1,285 +0,0 @@ -# -*- coding: utf-8 -*- -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import pytest -import time -import re -import logging -import urllib -from azure.iot.device.common.auth.sastoken import ( - RenewableSasToken, - NonRenewableSasToken, - SasTokenError, -) - -logging.basicConfig(level=logging.DEBUG) - -fake_uri = "some/resource/location" -fake_signed_data = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=" -fake_key_name = "fakekeyname" -fake_expiry = 12321312 - -simple_token_format = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}" -auth_rule_token_format = ( - "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}&skn={keyname}" -) - - -def token_parser(token_str): - """helper function that parses a token string for individual values""" - token_map = {} - kv_string = token_str.split(" ")[1] - kv_pairs = kv_string.split("&") - for kv in kv_pairs: - t = kv.split("=") - token_map[t[0]] = t[1] - return token_map - - -class RenewableSasTokenTestConfig(object): - @pytest.fixture - def signing_mechanism(self, mocker): - mechanism = mocker.MagicMock() - mechanism.sign.return_value = fake_signed_data - return mechanism - - # TODO: Rename this. 
These are not "device" and "service" tokens, the distinction is more generic - @pytest.fixture(params=["Device Token", "Service Token"]) - def sastoken(self, request, signing_mechanism): - token_type = request.param - if token_type == "Device Token": - return RenewableSasToken(uri=fake_uri, signing_mechanism=signing_mechanism) - elif token_type == "Service Token": - return RenewableSasToken( - uri=fake_uri, signing_mechanism=signing_mechanism, key_name=fake_key_name - ) - - -@pytest.mark.describe("RenewableSasToken") -class TestRenewableSasToken(RenewableSasTokenTestConfig): - @pytest.mark.it("Instantiates with a default TTL of 3600 seconds if no TTL is provided") - def test_default_ttl(self, signing_mechanism): - s = RenewableSasToken(fake_uri, signing_mechanism) - assert s.ttl == 3600 - - @pytest.mark.it("Instantiates with a custom TTL if provided") - def test_custom_ttl(self, signing_mechanism): - custom_ttl = 4747 - s = RenewableSasToken(fake_uri, signing_mechanism, ttl=custom_ttl) - assert s.ttl == custom_ttl - - @pytest.mark.it("Instantiates with with no key name by default if no key name is provided") - def test_default_key_name(self, signing_mechanism): - s = RenewableSasToken(fake_uri, signing_mechanism) - assert s._key_name is None - - @pytest.mark.it("Instantiates with the given key name if provided") - def test_custom_key_name(self, signing_mechanism): - s = RenewableSasToken(fake_uri, signing_mechanism, key_name=fake_key_name) - assert s._key_name == fake_key_name - - @pytest.mark.it( - "Instantiates with an expiry time TTL seconds in the future from the moment of instantiation" - ) - def test_expiry_time(self, mocker, signing_mechanism): - fake_current_time = 1000 - mocker.patch.object(time, "time", return_value=fake_current_time) - - s = RenewableSasToken(fake_uri, signing_mechanism) - assert s.expiry_time == fake_current_time + s.ttl - - @pytest.mark.it("Calls .refresh() to build the SAS token string on instantiation") - def 
test_refresh_on_instantiation(self, mocker, signing_mechanism): - refresh_mock = mocker.spy(RenewableSasToken, "refresh") - assert refresh_mock.call_count == 0 - RenewableSasToken(fake_uri, signing_mechanism) - assert refresh_mock.call_count == 1 - - @pytest.mark.it("Returns the SAS token string as the string representation of the object") - def test_str_rep(self, sastoken): - assert str(sastoken) == sastoken._token - - @pytest.mark.it( - "Maintains the .expiry_time attribute as a read-only property (raises AttributeError upon attempt)" - ) - def test_expiry_time_read_only(self, sastoken): - with pytest.raises(AttributeError): - sastoken.expiry_time = 12321312 - - -@pytest.mark.describe("RenewableSasToken - .refresh()") -class TestRenewableSasTokenRefresh(RenewableSasTokenTestConfig): - @pytest.mark.it("Sets a new expiry time of TTL seconds in the future") - def test_new_expiry(self, mocker, sastoken): - fake_current_time = 1000 - mocker.patch.object(time, "time", return_value=fake_current_time) - sastoken.refresh() - assert sastoken.expiry_time == fake_current_time + sastoken.ttl - - # TODO: reflect url encoding here? 
- @pytest.mark.it( - "Uses the token's signing mechanism to create a signature by signing a concatenation of the (URL encoded) URI and updated expiry time" - ) - def test_generate_new_token(self, mocker, signing_mechanism, sastoken): - old_token_str = str(sastoken) - fake_future_time = 1000 - mocker.patch.object(time, "time", return_value=fake_future_time) - signing_mechanism.reset_mock() - fake_signature = "new_fake_signature" - signing_mechanism.sign.return_value = fake_signature - - sastoken.refresh() - - # The token string has been updated - assert str(sastoken) != old_token_str - # The signing mechanism was used to sign a string - assert signing_mechanism.sign.call_count == 1 - # The string being signed was a concatenation of the URI and expiry time - assert signing_mechanism.sign.call_args == mocker.call( - urllib.parse.quote(sastoken._uri, safe="") + "\n" + str(sastoken.expiry_time) - ) - # The token string has the resulting signed string included as the signature - token_info = token_parser(str(sastoken)) - assert token_info["sig"] == fake_signature - - @pytest.mark.it( - "Builds a new token string using the token's URI (URL encoded) and expiry time, along with the signature created by the signing mechanism (also URL encoded)" - ) - def test_token_string(self, sastoken): - token_str = sastoken._token - - # Verify that token string representation matches token format - if not sastoken._key_name: - pattern = re.compile(r"SharedAccessSignature sr=(.+)&sig=(.+)&se=(.+)") - else: - pattern = re.compile(r"SharedAccessSignature sr=(.+)&sig=(.+)&se=(.+)&skn=(.+)") - assert pattern.match(token_str) - - # Verify that content in the string representation is correct - token_info = token_parser(token_str) - assert token_info["sr"] == urllib.parse.quote(sastoken._uri, safe="") - assert token_info["sig"] == urllib.parse.quote( - sastoken._signing_mechanism.sign.return_value, safe="" - ) - assert token_info["se"] == str(sastoken.expiry_time) - if sastoken._key_name: - 
assert token_info["skn"] == sastoken._key_name - - @pytest.mark.it("Raises a SasTokenError if an exception is raised by the signing mechanism") - def test_signing_mechanism_raises_value_error( - self, mocker, signing_mechanism, sastoken, arbitrary_exception - ): - signing_mechanism.sign.side_effect = arbitrary_exception - - with pytest.raises(SasTokenError) as e_info: - sastoken.refresh() - assert e_info.value.__cause__ is arbitrary_exception - - -@pytest.mark.describe("NonRenewableSasToken") -class TestNonRenewableSasToken(object): - # TODO: Rename this. These are not "device" and "service" tokens, the distinction is more generic - @pytest.fixture(params=["Device Token", "Service Token"]) - def sastoken_str(self, request): - token_type = request.param - if token_type == "Device Token": - return simple_token_format.format( - resource=urllib.parse.quote(fake_uri, safe=""), - signature=urllib.parse.quote(fake_signed_data, safe=""), - expiry=fake_expiry, - ) - elif token_type == "Service Token": - return auth_rule_token_format.format( - resource=urllib.parse.quote(fake_uri, safe=""), - signature=urllib.parse.quote(fake_signed_data, safe=""), - expiry=fake_expiry, - keyname=fake_key_name, - ) - - @pytest.fixture() - def sastoken(self, sastoken_str): - return NonRenewableSasToken(sastoken_str) - - @pytest.mark.it("Instantiates from a valid SAS Token string") - def test_instantiates_from_token_string(self, sastoken_str): - s = NonRenewableSasToken(sastoken_str) - assert s._token == sastoken_str - - @pytest.mark.it("Raises a SasToken error if instantiating from an invalid SAS Token string") - @pytest.mark.parametrize( - "invalid_token_str", - [ - pytest.param( - "sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se=12321312", - id="Incomplete token format", - ), - pytest.param( - "SharedERRORSignature sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se=12321312", - id="Invalid token format", - ), - pytest.param( - 
"SharedAccessignature sr=some%2Fresource%2Flocationsig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se12321312", - id="Token values incorectly formatted", - ), - pytest.param( - "SharedAccessSignature sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se=12321312", - id="Missing resource value", - ), - pytest.param( - "SharedAccessSignature sr=some%2Fresource%2Flocation&se=12321312", - id="Missing signature value", - ), - pytest.param( - "SharedAccessSignature sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=", - id="Missing expiry value", - ), - pytest.param( - "SharedAccessSignature sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se=12321312&foovalue=nonsense", - id="Extraneous invalid value", - ), - ], - ) - def test_raises_error_invalid_token_string(self, invalid_token_str): - with pytest.raises(SasTokenError): - NonRenewableSasToken(invalid_token_str) - - @pytest.mark.it("Returns the SAS token string as the string representation of the object") - def test_str_rep(self, sastoken_str): - sastoken = NonRenewableSasToken(sastoken_str) - assert str(sastoken) == sastoken_str - - @pytest.mark.it( - "Instantiates with the .expiry_time attribute corresponding to the expiry time of the given SAS Token string (as an integer)" - ) - def test_instantiates_expiry_time(self, sastoken_str): - sastoken = NonRenewableSasToken(sastoken_str) - expected_expiry_time = token_parser(sastoken_str)["se"] - assert sastoken.expiry_time == int(expected_expiry_time) - - @pytest.mark.it( - "Maintains the .expiry_time attribute as a read-only property (raises AttributeError upon attempt)" - ) - def test_expiry_time_read_only(self, sastoken): - with pytest.raises(AttributeError): - sastoken.expiry_time = 12312312312123 - - @pytest.mark.it( - "Instantiates with the .resource_uri attribute corresponding to the URL decoded URI of the given SAS Token string" - ) - def test_instantiates_resource_uri(self, sastoken_str): - sastoken = 
NonRenewableSasToken(sastoken_str) - resource_uri = token_parser(sastoken_str)["sr"] - assert resource_uri != sastoken.resource_uri - assert resource_uri == urllib.parse.quote(sastoken.resource_uri, safe="") - assert urllib.parse.unquote(resource_uri) == sastoken.resource_uri - - @pytest.mark.it( - "Maintains the .resource_uri attribute as a read-only property (raises AttributeError upon attempt)" - ) - def test_resource_uri_read_only(self, sastoken): - with pytest.raises(AttributeError): - sastoken.resource_uri = "new%2Ffake%2Furi" diff --git a/tests/unit/common/models/test_proxy_options.py b/tests/unit/common/models/test_proxy_options.py deleted file mode 100644 index 28a37fb49..000000000 --- a/tests/unit/common/models/test_proxy_options.py +++ /dev/null @@ -1,152 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import pytest -import logging -import socks -from azure.iot.device.common.models import ProxyOptions - -logging.basicConfig(level=logging.DEBUG) - - -@pytest.mark.describe("ProxyOptions") -class TestProxyOptions(object): - @pytest.mark.it( - "Instantiates with the 'proxy_type' and 'proxy_type_socks' properties set based on the value of the 'proxy_type' parameter" - ) - @pytest.mark.parametrize( - "proxy_type_input, expected_proxy_type, expected_proxy_type_socks", - [ - pytest.param("HTTP", "HTTP", socks.HTTP, id="HTTP (string)"), - pytest.param("SOCKS4", "SOCKS4", socks.SOCKS4, id="SOCKS4 (string)"), - pytest.param("SOCKS5", "SOCKS5", socks.SOCKS5, id="SOCKS5 (string)"), - # Backwards compatibility - pytest.param(socks.HTTP, "HTTP", socks.HTTP, id="HTTP (socks constant)"), - pytest.param(socks.SOCKS4, "SOCKS4", socks.SOCKS4, id="SOCKS4 (socks constant)"), - pytest.param(socks.SOCKS5, "SOCKS5", socks.SOCKS5, id="SOCKS5 (socks constant)"), - ], - ) - def test_proxy_type(self, proxy_type_input, expected_proxy_type, expected_proxy_type_socks): - proxy_options = ProxyOptions( - proxy_type=proxy_type_input, proxy_addr="127.0.0.1", proxy_port=8888 - ) - - assert proxy_options.proxy_type == expected_proxy_type - assert proxy_options.proxy_type_socks == expected_proxy_type_socks - - @pytest.mark.it("Maintains 'proxy_type' as a read-only property") - def test_proxy_type_read_only(self): - proxy_options = ProxyOptions(proxy_type="HTTP", proxy_addr="127.0.0.1", proxy_port=8888) - with pytest.raises(AttributeError): - proxy_options.proxy_type = "new value" - - @pytest.mark.it("Maintains 'proxy_type_socks' as a read-only property") - def test_proxy_type_socks_read_only(self): - proxy_options = ProxyOptions(proxy_type="HTTP", proxy_addr="127.0.0.1", proxy_port=8888) - with pytest.raises(AttributeError): - proxy_options.proxy_type_socks = "new value" - - @pytest.mark.it("Raises a ValueError if 
proxy_type is invalid") - def test_invalid_proxy_type(self): - with pytest.raises(ValueError): - ProxyOptions(proxy_type="INVALID", proxy_addr="127.0.0.1", proxy_port=8888) - - @pytest.mark.it( - "Instantiates with the 'proxy_address' property set to the value of the 'proxy_addr' parameter" - ) - def test_proxy_address(self): - proxy_addr = "127.0.0.1" - proxy_options = ProxyOptions(proxy_type=socks.HTTP, proxy_addr=proxy_addr, proxy_port=8888) - - assert proxy_options.proxy_address == proxy_addr - - @pytest.mark.it("Maintains 'proxy_address' as a read-only property") - def test_proxy_address_read_only(self): - proxy_options = ProxyOptions(proxy_type=socks.HTTP, proxy_addr="127.0.0.1", proxy_port=8888) - with pytest.raises(AttributeError): - proxy_options.proxy_address = "new value" - - @pytest.mark.it( - "Instantiates with the 'proxy_port' property set to the value of the 'proxy_port' parameter" - ) - def test_proxy_port(self): - proxy_port = 8888 - proxy_options = ProxyOptions( - proxy_type=socks.HTTP, proxy_addr="127.0.0.1", proxy_port=proxy_port - ) - - assert proxy_options.proxy_port == proxy_port - - @pytest.mark.it( - "Converts the 'proxy_port' property to an integer if the 'proxy_port' parameter is provided as a string" - ) - def test_proxy_port_conversion(self): - proxy_port = "8888" - proxy_options = ProxyOptions( - proxy_type=socks.HTTP, proxy_addr="127.0.0.1", proxy_port=proxy_port - ) - - assert proxy_options.proxy_port == int(proxy_port) - - @pytest.mark.it("Maintains 'proxy_port' as a read-only property") - def test_proxy_port_read_only(self): - proxy_options = ProxyOptions(proxy_type=socks.HTTP, proxy_addr="127.0.0.1", proxy_port=8888) - with pytest.raises(AttributeError): - proxy_options.proxy_port = "new value" - - @pytest.mark.it( - "Instantiates with the 'proxy_username' property set to the value of the 'proxy_username' parameter, if provided" - ) - def test_proxy_username(self): - proxy_username = "myusername" - proxy_options = ProxyOptions( - 
proxy_type=socks.HTTP, - proxy_addr="127.0.0.1", - proxy_port=8888, - proxy_username=proxy_username, - ) - - assert proxy_options.proxy_username == proxy_username - - @pytest.mark.it( - "Defaults the 'proxy_username' property to 'None' if the 'proxy_username' parameter is not provided" - ) - def test_proxy_username_default(self): - proxy_options = ProxyOptions(proxy_type=socks.HTTP, proxy_addr="127.0.0.1", proxy_port=8888) - assert proxy_options.proxy_username is None - - @pytest.mark.it("Maintains 'proxy_username' as a read-only property") - def test_proxy_username_read_only(self): - proxy_options = ProxyOptions(proxy_type=socks.HTTP, proxy_addr="127.0.0.1", proxy_port=8888) - with pytest.raises(AttributeError): - proxy_options.proxy_username = "new value" - - @pytest.mark.it( - "Instantiates with the 'proxy_password' property set to the value of the 'proxy_password' parameter, if provided" - ) - def test_proxy_password(self): - proxy_password = "fake_password" - proxy_options = ProxyOptions( - proxy_type=socks.HTTP, - proxy_addr="127.0.0.1", - proxy_port=8888, - proxy_password=proxy_password, - ) - - assert proxy_options.proxy_password == proxy_password - - @pytest.mark.it( - "Defaults the 'proxy_password' property to 'None' if the 'proxy_password' parameter is not provided" - ) - def test_proxy_password_default(self): - proxy_options = ProxyOptions(proxy_type=socks.HTTP, proxy_addr="127.0.0.1", proxy_port=8888) - assert proxy_options.proxy_password is None - - @pytest.mark.it("Maintains 'proxy_password' as a read-only property") - def test_proxy_password_read_only(self): - proxy_options = ProxyOptions(proxy_type=socks.HTTP, proxy_addr="127.0.0.1", proxy_port=8888) - with pytest.raises(AttributeError): - proxy_options.proxy_password = "new value" diff --git a/tests/unit/common/pipeline/__init__.py b/tests/unit/common/pipeline/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/common/pipeline/config_test.py 
b/tests/unit/common/pipeline/config_test.py deleted file mode 100644 index ee7430ee1..000000000 --- a/tests/unit/common/pipeline/config_test.py +++ /dev/null @@ -1,384 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import pytest -import abc -import threading -from azure.iot.device import ProxyOptions -from azure.iot.device.common.pipeline.config import DEFAULT_KEEPALIVE - - -class PipelineConfigInstantiationTestBase(abc.ABC): - """All PipelineConfig instantiation tests should inherit from this base class. - It provides tests for shared functionality among all PipelineConfigs, derived from - the BasePipelineConfig class. - """ - - @abc.abstractmethod - def config_cls(self): - """This must be implemented in the child test class. - It returns the child class under test""" - pass - - @abc.abstractmethod - def required_kwargs(self): - """This must be implemented in the child test class. - It returns required kwargs for the child class under test""" - pass - - # PipelineConfig objects require exactly one auth mechanism, sastoken or x509. - # For the sake of ease of testing, we will assume sastoken is being used unless - # otherwise specified. - # It does not matter which is used for the purposes of these tests. 
- - @pytest.fixture - def sastoken(self, mocker): - return mocker.MagicMock() - - @pytest.fixture - def x509(self, mocker): - return mocker.MagicMock() - - @pytest.mark.it( - "Instantiates with the 'hostname' attribute set to the provided 'hostname' parameter" - ) - def test_hostname_set(self, config_cls, required_kwargs, sastoken): - # Hostname is one of the required kwargs, because it is required for the child - config = config_cls(sastoken=sastoken, **required_kwargs) - assert config.hostname == required_kwargs["hostname"] - - @pytest.mark.it( - "Instantiates with the 'gateway_hostname' attribute set to the provided 'gateway_hostname' parameter" - ) - def test_gateway_hostname_set(self, config_cls, required_kwargs, sastoken): - fake_gateway_hostname = "gateway-hostname.some-domain.net" - config = config_cls( - sastoken=sastoken, gateway_hostname=fake_gateway_hostname, **required_kwargs - ) - assert config.gateway_hostname == fake_gateway_hostname - - @pytest.mark.it( - "Instantiates with the 'gateway_hostname' attribute set to 'None' if no 'gateway_hostname' parameter is provided" - ) - def test_gateway_hostname_default(self, config_cls, required_kwargs, sastoken): - config = config_cls(sastoken=sastoken, **required_kwargs) - assert config.gateway_hostname is None - - @pytest.mark.it( - "Instantiates with the 'keep_alive' attribute set to the provided 'keep_alive' parameter (converting the value to int)" - ) - @pytest.mark.parametrize( - "keep_alive", - [ - pytest.param(1, id="int"), - pytest.param(35.90, id="float"), - pytest.param(0b1010, id="binary"), - pytest.param(0x9, id="hexadecimal"), - pytest.param("7", id="Numeric string"), - ], - ) - def test_keep_alive_valid_with_conversion( - self, mocker, required_kwargs, config_cls, sastoken, keep_alive - ): - config = config_cls(sastoken=sastoken, keep_alive=keep_alive, **required_kwargs) - assert config.keep_alive == int(keep_alive) - - @pytest.mark.it( - "Instantiates with the 'keep_alive' attribute to 'None' 
if no 'keep_alive' parameter is provided" - ) - def test_keep_alive_default(self, config_cls, required_kwargs, sastoken): - config = config_cls(sastoken=sastoken, **required_kwargs) - assert config.keep_alive == DEFAULT_KEEPALIVE - - @pytest.mark.it("Raises TypeError if the provided 'keep_alive' parameter is not numeric") - @pytest.mark.parametrize( - "keep_alive", - [ - pytest.param("sectumsempra", id="non-numeric string"), - pytest.param((1, 2), id="tuple"), - pytest.param([1, 2], id="list"), - pytest.param(object(), id="object"), - ], - ) - def test_keep_alive_invalid_type(self, config_cls, required_kwargs, sastoken, keep_alive): - with pytest.raises(TypeError): - config_cls(sastoken=sastoken, keep_alive=keep_alive, **required_kwargs) - - @pytest.mark.it("Raises ValueError if the provided 'keep_alive' parameter has an invalid value") - @pytest.mark.parametrize( - "keep_alive", - [ - pytest.param(9876543210987654321098765432109876543210, id="> than max"), - pytest.param(-2001, id="negative"), - pytest.param(0, id="zero"), - ], - ) - def test_keep_alive_invalid_value( - self, mocker, required_kwargs, config_cls, sastoken, keep_alive - ): - with pytest.raises(ValueError): - config_cls(sastoken=sastoken, keep_alive=keep_alive, **required_kwargs) - - @pytest.mark.it( - "Instantiates with the 'sastoken' attribute set to the provided 'sastoken' parameter" - ) - def test_sastoken_set(self, config_cls, required_kwargs, sastoken): - config = config_cls(sastoken=sastoken, **required_kwargs) - assert config.sastoken is sastoken - - @pytest.mark.it( - "Instantiates with the 'sastoken' attribute set to 'None' if no 'sastoken' parameter is provided" - ) - def test_sastoken_default(self, config_cls, required_kwargs, x509): - config = config_cls(x509=x509, **required_kwargs) - assert config.sastoken is None - - @pytest.mark.it("Instantiates with the 'x509' attribute set to the provided 'x509' parameter") - def test_x509_set(self, config_cls, required_kwargs, x509): - config = 
config_cls(x509=x509, **required_kwargs) - assert config.x509 is x509 - - @pytest.mark.it( - "Instantiates with the 'x509' attribute set to 'None' if no 'x509 parameter is provided" - ) - def test_x509_default(self, config_cls, required_kwargs, sastoken): - config = config_cls(sastoken=sastoken, **required_kwargs) - assert config.x509 is None - - @pytest.mark.it( - "Raises a ValueError if neither the 'sastoken' nor 'x509' parameter is provided" - ) - def test_no_auths_provided(self, config_cls, required_kwargs): - with pytest.raises(ValueError): - config_cls(**required_kwargs) - - @pytest.mark.it("Raises a ValueError if both the 'sastoken' and 'x509' parameters are provided") - def test_both_auths_provided(self, config_cls, required_kwargs, sastoken, x509): - with pytest.raises(ValueError): - config_cls(sastoken=sastoken, x509=x509, **required_kwargs) - - @pytest.mark.it( - "Instantiates with the 'server_verification_cert' attribute set to the provided 'server_verification_cert' parameter" - ) - def test_server_verification_cert_set(self, config_cls, required_kwargs, sastoken): - svc = "fake_server_verification_cert" - config = config_cls(sastoken=sastoken, server_verification_cert=svc, **required_kwargs) - assert config.server_verification_cert == svc - - @pytest.mark.it( - "Instantiates with the 'server_verification_cert' attribute set to 'None' if no 'server_verification_cert' parameter is provided" - ) - def test_server_verification_cert_default(self, config_cls, required_kwargs, sastoken): - config = config_cls(sastoken=sastoken, **required_kwargs) - assert config.server_verification_cert is None - - @pytest.mark.it( - "Instantiates with the 'websockets' attribute set to the provided 'websockets' parameter" - ) - @pytest.mark.parametrize( - "websockets", [True, False], ids=["websockets == True", "websockets == False"] - ) - def test_websockets_set(self, config_cls, required_kwargs, sastoken, websockets): - config = config_cls(sastoken=sastoken, 
websockets=websockets, **required_kwargs) - assert config.websockets is websockets - - @pytest.mark.it( - "Instantiates with the 'websockets' attribute to 'False' if no 'websockets' parameter is provided" - ) - def test_websockets_default(self, config_cls, required_kwargs, sastoken): - config = config_cls(sastoken=sastoken, **required_kwargs) - assert config.websockets is False - - @pytest.mark.it( - "Instantiates with the 'cipher' attribute set to OpenSSL list formatted version of the provided 'cipher' parameter" - ) - @pytest.mark.parametrize( - "cipher_input, expected_cipher", - [ - pytest.param( - "DHE-RSA-AES128-SHA", - "DHE-RSA-AES128-SHA", - id="Single cipher suite, OpenSSL list formatted string", - ), - pytest.param( - "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", - "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", - id="Multiple cipher suites, OpenSSL list formatted string", - ), - pytest.param( - "DHE_RSA_AES128_SHA", - "DHE-RSA-AES128-SHA", - id="Single cipher suite, as string with '_' delimited algorithms/protocols", - ), - pytest.param( - "DHE_RSA_AES128_SHA:DHE_RSA_AES256_SHA:ECDHE_ECDSA_AES128_GCM_SHA256", - "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", - id="Multiple cipher suites, as string with '_' delimited algorithms/protocols and ':' delimited suites", - ), - pytest.param( - ["DHE-RSA-AES128-SHA"], - "DHE-RSA-AES128-SHA", - id="Single cipher suite, in a list, with '-' delimited algorithms/protocols", - ), - pytest.param( - ["DHE-RSA-AES128-SHA", "DHE-RSA-AES256-SHA", "ECDHE-ECDSA-AES128-GCM-SHA256"], - "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", - id="Multiple cipher suites, in a list, with '-' delimited algorithms/protocols", - ), - pytest.param( - ["DHE_RSA_AES128_SHA"], - "DHE-RSA-AES128-SHA", - id="Single cipher suite, in a list, with '_' delimited algorithms/protocols", - ), - pytest.param( - ["DHE_RSA_AES128_SHA", "DHE_RSA_AES256_SHA", 
"ECDHE_ECDSA_AES128_GCM_SHA256"], - "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", - id="Multiple cipher suites, in a list, with '_' delimited algorithms/protocols", - ), - ], - ) - def test_cipher(self, config_cls, required_kwargs, sastoken, cipher_input, expected_cipher): - config = config_cls(sastoken=sastoken, cipher=cipher_input, **required_kwargs) - assert config.cipher == expected_cipher - - @pytest.mark.it( - "Raises TypeError if the provided 'cipher' attribute is neither list nor string" - ) - @pytest.mark.parametrize( - "cipher", - [ - pytest.param(123, id="int"), - pytest.param( - {"cipher1": "DHE-RSA-AES128-SHA", "cipher2": "DHE_RSA_AES256_SHA"}, id="dict" - ), - pytest.param(object(), id="complex object"), - ], - ) - def test_invalid_cipher_param(self, config_cls, required_kwargs, sastoken, cipher): - with pytest.raises(TypeError): - config_cls(sastoken=sastoken, cipher=cipher, **required_kwargs) - - @pytest.mark.it( - "Instantiates with the 'cipher' attribute to empty string ('') if no 'cipher' parameter is provided" - ) - def test_cipher_default(self, config_cls, required_kwargs, sastoken): - config = config_cls(sastoken=sastoken, **required_kwargs) - assert config.cipher == "" - - @pytest.mark.it( - "Instantiates with the 'proxy_options' attribute set to the ProxyOptions object provided in the 'proxy_options' parameter" - ) - def test_proxy_options(self, mocker, required_kwargs, config_cls, sastoken): - proxy_options = ProxyOptions(proxy_type=1, proxy_addr="127.0.0.1", proxy_port=8888) - config = config_cls(sastoken=sastoken, proxy_options=proxy_options, **required_kwargs) - assert config.proxy_options is proxy_options - - @pytest.mark.it( - "Instantiates with the 'proxy_options' attribute to 'None' if no 'proxy_options' parameter is provided" - ) - def test_proxy_options_default(self, config_cls, required_kwargs, sastoken): - config = config_cls(sastoken=sastoken, **required_kwargs) - assert config.proxy_options is None - - 
@pytest.mark.it( - "Instantiates with the 'auto_connect' attribute set to the provided 'auto_connect' parameter" - ) - def test_auto_connect_set(self, config_cls, required_kwargs, sastoken): - config = config_cls(sastoken=sastoken, auto_connect=False, **required_kwargs) - assert config.auto_connect is False - - @pytest.mark.it( - "Instantiates with the 'auto_connect' attribute set to 'None' if no 'auto_connect' parameter is provided" - ) - def test_auto_connect_default(self, config_cls, required_kwargs, sastoken): - config = config_cls(sastoken=sastoken, **required_kwargs) - assert config.auto_connect is True - - @pytest.mark.it( - "Instantiates with the 'connection_retry' attribute set to the provided 'connection_retry' parameter" - ) - def test_connection_retry_set(self, config_cls, required_kwargs, sastoken): - config = config_cls(sastoken=sastoken, connection_retry=False, **required_kwargs) - assert config.connection_retry is False - - @pytest.mark.it( - "Instantiates with the 'connection_retry' attribute set to 'True' if no 'connection_retry' parameter is provided" - ) - def test_connection_retry_default(self, config_cls, required_kwargs, sastoken): - config = config_cls(sastoken=sastoken, **required_kwargs) - assert config.connection_retry is True - - @pytest.mark.it( - "Instantiates with the 'connection_retry_interval' attribute set to the provided 'connection_retry_interval' parameter (converting the value to int)" - ) - @pytest.mark.parametrize( - "connection_retry_interval", - [ - pytest.param(1, id="int"), - pytest.param(35.90, id="float"), - pytest.param(0b1010, id="binary"), - pytest.param(0x9, id="hexadecimal"), - pytest.param("7", id="Numeric string"), - ], - ) - def test_connection_retry_interval_set( - self, connection_retry_interval, config_cls, required_kwargs, sastoken - ): - config = config_cls( - sastoken=sastoken, - connection_retry_interval=connection_retry_interval, - **required_kwargs - ) - assert config.connection_retry_interval == 
int(connection_retry_interval) - - @pytest.mark.it( - "Instantiates with the 'connection_retry_interval' attribute set to 10 if no 'connection_retry_interval' parameter is provided" - ) - def test_connection_retry_interval_default(self, config_cls, required_kwargs, sastoken): - config = config_cls(sastoken=sastoken, **required_kwargs) - assert config.connection_retry_interval == 10 - - @pytest.mark.it( - "Raises a TypeError if the provided 'connection_retry_interval' parameter is not numeric" - ) - @pytest.mark.parametrize( - "connection_retry_interval", - [ - pytest.param("non-numeric-string", id="non-numeric string"), - pytest.param((1, 2), id="tuple"), - pytest.param([1, 2], id="list"), - pytest.param(object(), id="object"), - ], - ) - def test_connection_retry_interval_invalid_type( - self, config_cls, sastoken, required_kwargs, connection_retry_interval - ): - with pytest.raises(TypeError): - config_cls( - sastoken=sastoken, - connection_retry_interval=connection_retry_interval, - **required_kwargs - ) - - @pytest.mark.it( - "Raises a ValueError if the provided 'connection_retry_interval' parameter has an invalid value" - ) - @pytest.mark.parametrize( - "connection_retry_interval", - [ - pytest.param(threading.TIMEOUT_MAX + 1, id="> than max"), - pytest.param(-1, id="negative"), - pytest.param(0, id="zero"), - ], - ) - def test_connection_retry_interval_invalid_value( - self, config_cls, sastoken, required_kwargs, connection_retry_interval - ): - with pytest.raises(ValueError): - config_cls( - sastoken=sastoken, - connection_retry_interval=connection_retry_interval, - **required_kwargs - ) diff --git a/tests/unit/common/pipeline/fixtures.py b/tests/unit/common/pipeline/fixtures.py deleted file mode 100644 index e02dd253b..000000000 --- a/tests/unit/common/pipeline/fixtures.py +++ /dev/null @@ -1,101 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. 
-# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import pytest -import threading -from azure.iot.device.common.pipeline import ( - pipeline_events_base, - pipeline_ops_base, - pipeline_nucleus, -) - - -class ArbitraryEvent(pipeline_events_base.PipelineEvent): - def __init__(self): - super().__init__() - - -@pytest.fixture -def arbitrary_event(): - return ArbitraryEvent() - - -class ArbitraryOperation(pipeline_ops_base.PipelineOperation): - def __init__(self, callback=None): - super().__init__(callback=callback) - - -@pytest.fixture -def arbitrary_op(mocker): - op = ArbitraryOperation(callback=mocker.MagicMock()) - mocker.spy(op, "complete") - return op - - -@pytest.fixture -def pipeline_connected_mock(mocker): - """This mock can have it's return value altered by any test to indicate whether or not the - pipeline is connected (boolean). - - Because this fixture is used by the nucleus fixture, and the nucleus is the single source of - truth for connection, changing this fixture's return value will change the connection state - of any other aspect of the pipeline (assuming it is using the nucleus fixture). - - This has to be it's own fixture, because due to how PropertyMocks work, you can't access them - on an instance of an object like you can, say, the mocked settings on a PipelineConfiguration - """ - p = mocker.PropertyMock() - return p - - -@pytest.fixture -def nucleus(mocker, pipeline_connected_mock): - """This fixture can be used to configure stages. Connection status can be mocked - via the above pipeline_connected_mock, but by default .connected will return a real value. 
- This nucleus will also come configured with a mocked pipeline configuration, which can be - overridden if necessary - """ - # Need to use a mock for pipeline config because we don't know - # what type of config is being used since these are common - nucleus = pipeline_nucleus.PipelineNucleus(pipeline_configuration=mocker.MagicMock()) - - # By default, set the connected mock to return the real connected value - # (this can be overridden by changing the return value of pipeline_connected_mock) - def dynamic_return(): - if not isinstance(pipeline_connected_mock.return_value, mocker.Mock): - return pipeline_connected_mock.return_value - return nucleus.connection_state is pipeline_nucleus.ConnectionState.CONNECTED - - pipeline_connected_mock.side_effect = dynamic_return - type(nucleus).connected = pipeline_connected_mock - - return nucleus - - -@pytest.fixture -def fake_pipeline_thread(): - """ - This fixture mocks out the thread name so that the pipeline decorators - use to assert that you are in a pipeline thread. - """ - this_thread = threading.current_thread() - old_name = this_thread.name - - this_thread.name = "pipeline" - yield - this_thread.name = old_name - - -@pytest.fixture -def fake_non_pipeline_thread(): - """ - This fixture sets thread name to something other than "pipeline" to force asserts - """ - this_thread = threading.current_thread() - old_name = this_thread.name - - this_thread.name = "not pipeline" - yield - this_thread.name = old_name diff --git a/tests/unit/common/pipeline/helpers.py b/tests/unit/common/pipeline/helpers.py deleted file mode 100644 index 654f1a5f0..000000000 --- a/tests/unit/common/pipeline/helpers.py +++ /dev/null @@ -1,83 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import pytest - - -class StageRunOpTestBase(object): - """All PipelineStage .run_op() tests should inherit from this base class. - It provides basic tests for dealing with exceptions. - """ - - @pytest.mark.it( - "Completes the operation with failure if an unexpected Exception is raised while executing the operation and the operation has not yet completed" - ) - def test_completes_operation_with_error(self, mocker, stage, op, arbitrary_exception): - stage._run_op = mocker.MagicMock(side_effect=arbitrary_exception) - - stage.run_op(op) - - assert op.completed - assert op.error is arbitrary_exception - - @pytest.mark.it( - "Allows an unexpected Exception to propagate if it is raised after the operation has already been completed" - ) - def test_exception_after_op_completed(self, mocker, stage, op, arbitrary_exception): - stage._run_op = mocker.MagicMock(side_effect=arbitrary_exception) - op.completed = True - - with pytest.raises(arbitrary_exception.__class__) as e_info: - stage.run_op(op) - assert e_info.value is arbitrary_exception - - @pytest.mark.it( - "Allows any BaseException that was raised during execution of the operation to propagate" - ) - def test_base_exception_propagates(self, mocker, stage, op, arbitrary_base_exception): - stage._run_op = mocker.MagicMock(side_effect=arbitrary_base_exception) - - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - stage.run_op(op) - assert e_info.value is arbitrary_base_exception - - -class StageHandlePipelineEventTestBase(object): - """All PipelineStage .handle_pipeline_event() tests should inherit from this base class. - It provides basic tests for dealing with exceptions. 
- """ - - @pytest.mark.it( - "Raise any unexpected Exceptions raised during handling of the event as background exceptions, if a previous stage exists" - ) - def test_uses_background_exception_handler(self, mocker, stage, event, arbitrary_exception): - stage.previous = mocker.MagicMock() # force previous stage - stage._handle_pipeline_event = mocker.MagicMock(side_effect=arbitrary_exception) - - stage.handle_pipeline_event(event) - - assert stage.report_background_exception.call_count == 1 - assert stage.report_background_exception.call_args == mocker.call(arbitrary_exception) - - @pytest.mark.it( - "Drops any unexpected Exceptions raised during handling of the event if no previous stage exists" - ) - def test_exception_with_no_previous_stage(self, mocker, stage, event, arbitrary_exception): - stage.previous = None - stage._handle_pipeline_event = mocker.MagicMock(side_effect=arbitrary_exception) - - stage.handle_pipeline_event(event) - - assert stage.report_background_exception.call_count == 0 - # No background exception process. No errors were raised. - # Logging did also occur here, but we don't test logs - - @pytest.mark.it("Allows any BaseException raised during handling of the event to propagate") - def test_base_exception_propagates(self, mocker, stage, event, arbitrary_base_exception): - stage._handle_pipeline_event = mocker.MagicMock(side_effect=arbitrary_base_exception) - - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - stage.handle_pipeline_event(event) - assert e_info.value is arbitrary_base_exception diff --git a/tests/unit/common/pipeline/pipeline_event_test.py b/tests/unit/common/pipeline/pipeline_event_test.py deleted file mode 100644 index 3d09cf2c2..000000000 --- a/tests/unit/common/pipeline/pipeline_event_test.py +++ /dev/null @@ -1,120 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import pytest -import inspect - -fake_count = 0 - -# CT-TODO: refactor this module in order to be more like pipeline_ops_test.py - - -def get_next_fake_value(): - """ - return a new "unique" fake value that can be used to test that attributes - are set correctly. Even if we expect a particular attribute to hold something - besides a string (like an array or an object), we can still test using a string, - so we do. - """ - global fake_count - fake_count = fake_count + 1 - return "__fake_value_{}__".format(fake_count) - - -base_event_defaults = {} - - -def add_event_test(cls, module, extra_defaults={}, positional_arguments=[], keyword_arguments={}): - """ - Add a test class to test the given PipelineOperation class. The class that - we're testing is passed in the cls parameter, and the different initialization - constants are passed with the named arguments that follow. - """ - all_extra_defaults = extra_defaults.copy() - all_extra_defaults.update(name=cls.__name__) - - add_instantiation_test( - cls=cls, - module=module, - defaults=base_event_defaults, - extra_defaults=all_extra_defaults, - positional_arguments=positional_arguments, - keyword_arguments=keyword_arguments, - ) - - -def add_instantiation_test( - cls, module, defaults, extra_defaults={}, positional_arguments=[], keyword_arguments={} -): - """ - internal function that takes the class and attribute details and adds a test class which - validates that the given class properly implements the given attributes. - """ - - # `defaults` contains an array of object attributes that should be set when - # we call the initializer with all of the required positional arguments - # and none of the optional keyword arguments. 
- - all_defaults = defaults.copy() - for key in extra_defaults: - all_defaults[key] = extra_defaults[key] - for key in keyword_arguments: - all_defaults[key] = keyword_arguments[key] - - # `args` contains an array of positional argument that we are passing to test that they - # get assigned to the correct attribute. - args = [get_next_fake_value() for i in range(len(positional_arguments))] - - # `kwargs` contains a dictionary of all keyword arguments, which includes required positional - # arguments and optional keyword arguments. - kwargs = {} - for key in positional_arguments: - kwargs[key] = get_next_fake_value() - for key in keyword_arguments: - kwargs[key] = get_next_fake_value() - - # LocalTestObject is a local class which tests the object that was passed in. pytest doesn't test - # against this local object, but it does test against it when we put it into the module namespace - # for the module that was passed in. - @pytest.mark.describe("{} - Instantiation".format(cls.__name__)) - class LocalTestObject(object): - @pytest.mark.it( - "Accepts {} positional arguments that get assigned to attributes of the same name: {}".format( - len(positional_arguments), ", ".join(positional_arguments) - ) - if len(positional_arguments) - else "Accepts no positional arguments" - ) - def test_positional_arguments(self): - instance = cls(*args) - for i in range(len(args)): - assert getattr(instance, positional_arguments[i]) == args[i] - - @pytest.mark.it( - "Accepts the following keyword arguments that get assigned to attributes of the same name: {}".format( - ", ".join(kwargs.keys()) if len(kwargs) else "None" - ) - ) - def test_keyword_arguments(self): - instance = cls(**kwargs) - for key in kwargs: - assert getattr(instance, key) == kwargs[key] - - @pytest.mark.it( - "Has the following default attributes: {}".format( - ", ".join(["{}={}".format(key, repr(all_defaults[key])) for key in all_defaults]) - ) - ) - def test_defaults(self): - instance = cls(*args) - for key in 
all_defaults: - if inspect.isclass(all_defaults[key]): - assert isinstance(getattr(instance, key), all_defaults[key]) - else: - assert getattr(instance, key) == all_defaults[key] - - # Adding this object to the namespace of the module that was passed in (using a name that starts with "Test") - # will cause pytest to pick it up. - setattr(module, "Test{}Instantiation".format(cls.__name__), LocalTestObject) diff --git a/tests/unit/common/pipeline/pipeline_ops_test.py b/tests/unit/common/pipeline/pipeline_ops_test.py deleted file mode 100644 index 69a646290..000000000 --- a/tests/unit/common/pipeline/pipeline_ops_test.py +++ /dev/null @@ -1,707 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import pytest -import logging -from azure.iot.device.common.pipeline.pipeline_ops_base import PipelineOperation -from azure.iot.device.common.pipeline import pipeline_exceptions - -logging.basicConfig(level=logging.DEBUG) - - -def add_operation_tests( - test_module, - op_class_under_test, - op_test_config_class, - extended_op_instantiation_test_class=None, -): - """ - Add shared tests for an Operation class to a testing module. - These tests need to be done for every Operation class. - - :param test_module: A reference to the test module to add tests to - :param op_class_under_test: A reference to the specific Operation class under test - :param op_test_config_class: A class providing fixtures specific to the Operation class - under test. 
This class must define the following fixtures: - - "cls_type" (which returns a reference to the Operation class under test) - - "init_kwargs" (which returns a dictionary of kwargs and associated values used to - instantiate the class) - :param extended_op_instantiation_test_class: A class defining instantiation tests that are - specific to the Operation class under test, and not shared with all Operations. - Note that you may override shared instantiation tests defined in this function within - the provided test class (e.g. test_needs_connection) - """ - - # Extend the provided test config class - class OperationTestConfigClass(op_test_config_class): - @pytest.fixture - def op(self, cls_type, init_kwargs, mocker): - op = cls_type(**init_kwargs) - mocker.spy(op, "complete") - return op - - @pytest.mark.describe("{} - Instantiation".format(op_class_under_test.__name__)) - class OperationBaseInstantiationTests(OperationTestConfigClass): - @pytest.mark.it("Initializes 'name' attribute as the classname") - def test_name(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.name == op.__class__.__name__ - - @pytest.mark.it("Initializes 'completed' attribute as False") - def test_completed(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.completed is False - - @pytest.mark.it("Initializes 'completing' attribute as False") - def test_completing(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.completing is False - - @pytest.mark.it("Initializes 'error' attribute as None") - def test_error(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.error is None - - # NOTE: this test should be overridden for operations that set this value to True - @pytest.mark.it("Initializes 'needs_connection' attribute as False") - def test_needs_connection(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.needs_connection is False - - @pytest.mark.it("Initializes 'callback_stack' list 
attribute with the provided callback") - def test_callback_added_to_list(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert len(op.callback_stack) == 1 - assert op.callback_stack[0] is init_kwargs["callback"] - - # If an extended operation instantiation test class is provided, use those tests as well. - # By using the extended_op_instantiation_test_class as the first parent class, this ensures that - # tests from OperationBaseInstantiationTests (e.g. test_needs_connection) can be overwritten by - # tests provided in extended_op_instantiation_test_class. - if extended_op_instantiation_test_class: - - class OperationInstantiationTests( - extended_op_instantiation_test_class, OperationBaseInstantiationTests - ): - pass - - else: - - class OperationInstantiationTests(OperationBaseInstantiationTests): - pass - - @pytest.mark.describe("{} - .add_callback()".format(op_class_under_test.__name__)) - class OperationAddCallbackTests(OperationTestConfigClass): - @pytest.fixture( - params=["Currently completing with no error", "Currently completing with error"] - ) - def error(self, request, arbitrary_exception): - if request.param == "Currently completing with no error": - return None - else: - return arbitrary_exception - - @pytest.mark.it("Adds a callback to the operation's callback stack'") - def test_adds_callback(self, mocker, op): - # Because op was instantiated with a callback, because 'callback' is a - # required parameter, there will already be one callback on the stack - # before we add additional ones. 
- assert len(op.callback_stack) == 1 - cb1 = mocker.MagicMock() - op.add_callback(cb1) - assert len(op.callback_stack) == 2 - assert op.callback_stack[1] == cb1 - - cb2 = mocker.MagicMock() - op.add_callback(cb2) - assert len(op.callback_stack) == 3 - assert op.callback_stack[1] == cb1 - assert op.callback_stack[2] == cb2 - - @pytest.mark.it( - "Raises an OperationError if attempting to add a callback to an already-completed operation" - ) - def test_already_completed_callback(self, mocker, op): - op.complete() - assert op.completed - - with pytest.raises(pipeline_exceptions.OperationError): - op.add_callback(mocker.MagicMock()) - - @pytest.mark.it( - "Raises an OperationError if attempting to add a callback to an operation that is currently undergoing the completion process" - ) - def test_currently_completing(self, mocker, op, error): - def cb(op, error): - with pytest.raises(pipeline_exceptions.OperationError): - # Add a callback during completion of the callback, i.e. while op completion is in progress - op.add_callback(mocker.MagicMock()) - - mock_cb = mocker.MagicMock(side_effect=cb) - op.add_callback(mock_cb) - - op.complete(error) - - assert mock_cb.call_count == 1 - - @pytest.mark.describe("{} - .spawn_worker_op()".format(op_class_under_test.__name__)) - class OperationSpawnWorkerOpTests(OperationTestConfigClass): - @pytest.fixture - def worker_op_type(self): - class SomeOperationType(PipelineOperation): - def __init__(self, arg1, arg2, arg3, callback): - super().__init__(callback=callback) - - return SomeOperationType - - @pytest.fixture - def worker_op_kwargs(self): - kwargs = {"arg1": 1, "arg2": 2, "arg3": 3} - return kwargs - - @pytest.mark.it( - "Creates and returns an new instance of the Operation class specified in the 'worker_op_type' parameter" - ) - def test_returns_worker_op_instance(self, op, worker_op_type, worker_op_kwargs): - worker_op = op.spawn_worker_op(worker_op_type, **worker_op_kwargs) - assert isinstance(worker_op, worker_op_type) - - 
@pytest.mark.it( - "Instantiates the returned worker operation using the provided **kwargs parameters (not including 'callback')" - ) - def test_creates_worker_op_with_provided_kwargs(self, mocker, op, worker_op_kwargs): - mock_instance = mocker.MagicMock() - mock_type = mocker.MagicMock(return_value=mock_instance) - mock_type.__name__ = "mock type" # this is needed for log statements - assert "callback" not in worker_op_kwargs - - worker_op = op.spawn_worker_op(mock_type, **worker_op_kwargs) - - assert worker_op is mock_instance - assert mock_type.call_count == 1 - - # Show that all provided kwargs are used. Note that this test does NOT show that - # ONLY the provided kwargs are used - because there ARE additional kwargs added. - for kwarg in worker_op_kwargs: - assert mock_type.call_args[1][kwarg] == worker_op_kwargs[kwarg] - - @pytest.mark.it( - "Adds a secondary callback to the worker operation after instantiation, if 'callback' is included in the provided **kwargs parameters" - ) - def test_adds_callback_to_worker_op(self, mocker, op, worker_op_kwargs): - mock_instance = mocker.MagicMock() - mock_type = mocker.MagicMock(return_value=mock_instance) - mock_type.__name__ = "mock type" # this is needed for log statements - worker_op_kwargs["callback"] = mocker.MagicMock() - - worker_op = op.spawn_worker_op(mock_type, **worker_op_kwargs) - - assert worker_op is mock_instance - assert mock_type.call_count == 1 - - # The callback used for instantiating the worker operation is NOT the callback provided in **kwargs - assert mock_type.call_args[1]["callback"] is not worker_op_kwargs["callback"] - - # The callback provided in **kwargs is applied after instantiation - assert mock_instance.add_callback.call_count == 1 - assert mock_instance.add_callback.call_args == mocker.call(worker_op_kwargs["callback"]) - - @pytest.mark.it( - "Raises TypeError if the provided **kwargs parameters do not match the constructor for the class provided in the 'worker_op_type' parameter" - ) 
- def test_incorrect_kwargs(self, mocker, op, worker_op_type, worker_op_kwargs): - worker_op_kwargs["invalid_kwarg"] = "some value" - - with pytest.raises(TypeError): - op.spawn_worker_op(worker_op_type, **worker_op_kwargs) - - @pytest.mark.it( - "Returns a worker operation, which, when completed, completes the operation that spawned it with the same error status" - ) - @pytest.mark.parametrize( - "use_error", [pytest.param(False, id="No Error"), pytest.param(True, id="With Error")] - ) - def test_worker_op_completes_original_op( - self, mocker, use_error, arbitrary_exception, op, worker_op_type, worker_op_kwargs - ): - original_op = op - - if use_error: - error = arbitrary_exception - else: - error = None - - worker_op = original_op.spawn_worker_op(worker_op_type, **worker_op_kwargs) - assert not original_op.completed - - worker_op.complete(error=error) - - # Worker op has been completed with the given error state - assert worker_op.completed - assert worker_op.error is error - - # Original op is now completed with the same given error state - assert original_op.completed - assert original_op.error is error - - @pytest.mark.it( - "Returns a worker operation, which, when completed, triggers the 'callback' optionally provided in the **kwargs parameter, prior to completing the operation that spawned it" - ) - @pytest.mark.parametrize( - "use_error", [pytest.param(False, id="No Error"), pytest.param(True, id="With Error")] - ) - def test_worker_op_triggers_own_callback_and_then_completes_original_op( - self, mocker, use_error, arbitrary_exception, op, worker_op_type, worker_op_kwargs - ): - original_op = op - - def callback(op, error): - # Assert this callback is called before the original op begins the completion process - assert not original_op.completed - assert original_op.complete.call_count == 0 - - cb_mock = mocker.MagicMock(side_effect=callback) - - worker_op_kwargs["callback"] = cb_mock - - if use_error: - error = arbitrary_exception - else: - error = None - 
- worker_op = original_op.spawn_worker_op(worker_op_type, **worker_op_kwargs) - assert original_op.complete.call_count == 0 - - worker_op.complete(error=error) - - # Provided callback was called - assert cb_mock.call_count == 1 - assert cb_mock.call_args == mocker.call(op=worker_op, error=error) - - # Worker op was completed - assert worker_op.completed - - # The original op that spawned the worker is also completed - assert original_op.completed - assert original_op.complete.call_count == 1 - assert original_op.complete.call_args == mocker.call(error=error) - - @pytest.mark.describe("{} - .complete()".format(op_class_under_test.__name__)) - class OperationCompleteTests(OperationTestConfigClass): - @pytest.fixture(params=["Successful completion", "Completion with error"]) - def error(self, request, arbitrary_exception): - if request.param == "Successful completion": - return None - else: - return arbitrary_exception - - @pytest.mark.it( - "Triggers and removes callbacks from the operation's callback stack according to LIFO order, passing the operation and any error to each callback" - ) - def test_trigger_callbacks(self, mocker, cls_type, init_kwargs, error): - # Set up callback mocks - cb1_mock = mocker.MagicMock() - cb2_mock = mocker.MagicMock() - cb3_mock = mocker.MagicMock() - - def cb1(op, error): - # All callbacks have been triggered - assert cb1_mock.call_count == 1 - assert cb2_mock.call_count == 1 - assert cb3_mock.call_count == 1 - assert len(op.callback_stack) == 0 - - def cb2(op, error): - # Callback 3 and Callback 2 have been triggered, but Callback 1 has not - assert cb1_mock.call_count == 0 - assert cb2_mock.call_count == 1 - assert cb3_mock.call_count == 1 - assert len(op.callback_stack) == 1 - - def cb3(op, error): - # Callback 3 has been triggered, but no others have been. 
- assert cb1_mock.call_count == 0 - assert cb2_mock.call_count == 0 - assert cb3_mock.call_count == 1 - assert len(op.callback_stack) == 2 - - cb1_mock.side_effect = cb1 - cb2_mock.side_effect = cb2 - cb3_mock.side_effect = cb3 - - # Attach callbacks to op - init_kwargs["callback"] = cb1_mock - op = cls_type(**init_kwargs) - op.add_callback(cb2_mock) - op.add_callback(cb3_mock) - assert len(op.callback_stack) == 3 - assert not op.completed - - # Run the completion - op.complete(error=error) - - assert op.completed - assert cb3_mock.call_count == 1 - assert cb3_mock.call_args == mocker.call(op=op, error=error) - assert cb2_mock.call_count == 1 - assert cb2_mock.call_args == mocker.call(op=op, error=error) - assert cb1_mock.call_count == 1 - assert cb1_mock.call_args == mocker.call(op=op, error=error) - - @pytest.mark.it( - "Sets the 'error' attribute to the specified error (if any) at the beginning of the completion process" - ) - def test_sets_error(self, mocker, op, error): - original_err = error - - def cb(op, error): - # During the completion process, the 'error' attribute has been set - assert op.error is original_err - assert error is original_err - - cb_mock = mocker.MagicMock(side_effect=cb) - op.add_callback(cb_mock) - - op.complete(error=error) - - # Callback was triggered during completion - assert cb_mock.call_count == 1 - - # After the completion process, the 'error' attribute is still set - assert op.error is error - - @pytest.mark.it( - "Sets the 'completing' attribute to True only for the duration of the completion process" - ) - def test_completing_set(self, mocker, op, error): - def cb(op, error): - # The operation is completing, but not completed - assert op.completing - assert not op.completed - - cb_mock = mocker.MagicMock(side_effect=cb) - op.add_callback(cb_mock) - - op.complete(error) - - # Callback was called - assert cb_mock.call_count == 1 - - # Once completed, the op is no longer completing - assert not op.completing - assert op.completed 
- - @pytest.mark.it( - "Raises an OperationError if an exception is raised during execution of a callback to propagate" - ) - def test_callback_raises_exception( - self, mocker, arbitrary_exception, cls_type, init_kwargs, error - ): - # Set up callback mocks - cb1_mock = mocker.MagicMock() - cb2_mock = mocker.MagicMock(side_effect=arbitrary_exception) - cb3_mock = mocker.MagicMock() - - # Attach callbacks to op - init_kwargs["callback"] = cb1_mock - op = cls_type(**init_kwargs) - op.add_callback(cb2_mock) - op.add_callback(cb3_mock) - assert len(op.callback_stack) == 3 - - # OperationError is raised - with pytest.raises(pipeline_exceptions.OperationError) as e_info: - op.complete(error=error) - # OperationError is derived from the original exception raised - assert e_info.value.__cause__ is arbitrary_exception - - # Due to the BaseException raised during CB2 propagating, CB1 is never triggered - assert cb3_mock.call_count == 1 - assert cb2_mock.call_count == 1 - assert cb1_mock.call_count == 0 - - @pytest.mark.it( - "Leaves the operation in an uncompleted state if an Exception is raised during execution of a callback" - ) - def test_callback_exc_raised_state( - self, mocker, arbitrary_exception, cls_type, init_kwargs, error - ): - cb_mock = mocker.MagicMock(side_effect=arbitrary_exception) - - init_kwargs["callback"] = cb_mock - op = cls_type(**init_kwargs) - - # Exception from callback is raised - with pytest.raises(pipeline_exceptions.OperationError) as e_info: - op.complete(error=error) - assert e_info.value.__cause__ is arbitrary_exception - - # Completion process has been suspended. 
- assert not op.completed - assert not op.completing - - # No error is set - assert op.error is None - - @pytest.mark.it( - "Allows any BaseExceptions raised during execution of a callback to immediately propagate" - ) - def test_callback_raises_base_exception( - self, mocker, arbitrary_base_exception, cls_type, init_kwargs, error - ): - # Set up callback mocks - cb1_mock = mocker.MagicMock() - cb2_mock = mocker.MagicMock(side_effect=arbitrary_base_exception) - cb3_mock = mocker.MagicMock() - - # Attach callbacks to op - init_kwargs["callback"] = cb1_mock - op = cls_type(**init_kwargs) - op.add_callback(cb2_mock) - op.add_callback(cb3_mock) - assert len(op.callback_stack) == 3 - - # BaseException from callback is raised - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - op.complete(error=error) - assert e_info.value is arbitrary_base_exception - - # Due to the BaseException raised during CB2 propagating, CB1 is never triggered - assert cb3_mock.call_count == 1 - assert cb2_mock.call_count == 1 - assert cb1_mock.call_count == 0 - - # The BaseException immediately leads to propagation, no graceful adjustment of state - assert op.completing - assert not op.completed - assert op.error is error - - @pytest.mark.it( - "Halts triggering of callbacks if a callback invokes the .halt_completion() method, leaving un-triggered callbacks in the operation's callback stack" - ) - def test_halt_during_callback(self, mocker, cls_type, init_kwargs, error): - def cb2(op, error): - # Halt the operation completion as part of the callback - op.halt_completion() - - # Set up callback mocks - cb1_mock = mocker.MagicMock() - cb2_mock = mocker.MagicMock(side_effect=cb2) - cb3_mock = mocker.MagicMock() - - # Attach callbacks to op - init_kwargs["callback"] = cb1_mock - op = cls_type(**init_kwargs) - op.add_callback(cb2_mock) - op.add_callback(cb3_mock) - assert not op.completed - assert len(op.callback_stack) == 3 - - op.complete(error=error) - - # Callback was NOT 
completed - assert not op.completed - - # Callback resolution was halted after CB2 due to the operation completion being halted - assert cb3_mock.call_count == 1 - assert cb2_mock.call_count == 1 - assert cb1_mock.call_count == 0 - - assert len(op.callback_stack) == 1 - assert op.callback_stack[0] is cb1_mock - - @pytest.mark.it( - "Marks the operation as fully completed by setting the 'completed' attribute to True, only once all callbacks have been triggered" - ) - def test_marks_complete(self, mocker, op, error): - # Set up callback mocks - cb1_mock = mocker.MagicMock() - cb2_mock = mocker.MagicMock() - - def cb(op, error): - assert not op.completed - - cb1_mock.side_effect = cb - cb2_mock.side_effect = cb - - op.add_callback(cb1_mock) - op.add_callback(cb2_mock) - - op.complete(error=error) - assert op.completed - - # Callbacks were called - assert cb1_mock.call_count == 1 - assert cb2_mock.call_count == 1 - - @pytest.mark.it( - "Raises an OperationError, without making any changes to the operation, if the operation has already been completed" - ) - def test_already_complete(self, mocker, op, error): - # Complete the operation - op.complete(error=error) - assert op.completed - - # Get the operation state - original_op_err_state = op.error - original_op_completion_state = op.completed - - # Attempt to complete the op again - with pytest.raises(pipeline_exceptions.OperationError): - op.complete(error=error) - - # The operation state is unchanged - assert op.error is original_op_err_state - assert op.completed is original_op_completion_state - - @pytest.mark.it( - "Raises an OperationError, without making any changes to the operation, if the operation is already in the process of completing" - ) - def test_already_completing(self, mocker, op, error): - def cb(op, error): - # Get the operation state - original_op_err_state = op.error - original_op_completion_state = op.completed - - # Attempt to complete the operation again while it is already in the process of 
completing - with pytest.raises(pipeline_exceptions.OperationError): - op.complete(error=error) - - # The operation state is unchanged - assert op.error is original_op_err_state - assert op.completed is original_op_completion_state - - cb_mock = mocker.MagicMock(side_effect=cb) - - op.add_callback(cb_mock) - op.complete(error=error) - - @pytest.mark.it( - "Raises an OperationError if the operation is somehow completed while still undergoing the process of completion" - ) - def test_invalid_complete_during_completion(self, mocker, op, error): - # This should never happen, as this is an invalid scenario, and could only happen due - # to a bug elsewhere in the code (e.g. manually change the boolean, as in this test) - - def cb(op, error): - op.completed = True - - cb_mock = mocker.MagicMock(side_effect=cb) - - op.add_callback(cb_mock) - - with pytest.raises(pipeline_exceptions.OperationError): - op.complete(error=error) - - assert cb_mock.call_count == 1 - - @pytest.mark.it( - "Completes the operation successfully (no error) by default if no error is specified" - ) - def test_error_default(self, mocker, cls_type, init_kwargs): - cb_mock = mocker.MagicMock() - init_kwargs["callback"] = cb_mock - op = cls_type(**init_kwargs) - assert not op.completed - - op.complete() - - assert op.completed - assert op.error is None - assert cb_mock.call_count == 1 - # Callback was called passing 'None' as the error - assert cb_mock.call_args == mocker.call(op=op, error=None) - - @pytest.mark.describe("{} - .halt_completion()".format(op_class_under_test.__name__)) - class OperationHaltCompletionTests(OperationTestConfigClass): - @pytest.fixture( - params=["Currently completing with no error", "Currently completing with error"] - ) - def error(self, request, arbitrary_exception): - if request.param == "Currently completing with no error": - return None - else: - return arbitrary_exception - - @pytest.mark.it( - "Marks the operation as no longer completing by setting the 'completing' 
attribute to False, if the operation is currently in the process of completion" - ) - def test_sets_completing_false(self, mocker, op, error): - def cb(op, error): - assert op.completing - assert not op.completed - op.halt_completion() - assert not op.completing - - cb_mock = mocker.MagicMock(side_effect=cb) - op.add_callback(cb_mock) - - op.complete(error=error) - - assert not op.completing - assert not op.completed - assert cb_mock.call_count == 1 - - @pytest.mark.it( - "Clears the existing error in the operation's 'error' attribute, if the operation is currently in the process of completion with error" - ) - def test_clears_error(self, mocker, op, error): - completion_error = error - - def cb(op, error): - assert op.completing - assert op.error is completion_error - op.halt_completion() - assert not op.completing - assert op.error is None - - cb_mock = mocker.MagicMock(side_effect=cb) - op.add_callback(cb_mock) - - op.complete(error=completion_error) - - assert op.error is None - assert cb_mock.call_count == 1 - - @pytest.mark.it( - "Raises an OperationError if the operation has already been fully completed" - ) - def test_already_completed_op(self, mocker, op): - op.complete() - assert op.completed - - with pytest.raises(pipeline_exceptions.OperationError): - op.halt_completion() - - @pytest.mark.it( - "Sends an OperationError to the background exception handler if the operation has never been completed" - ) - def test_never_completed_op(self, mocker, op): - with pytest.raises(pipeline_exceptions.OperationError): - op.halt_completion() - - setattr( - test_module, - "Test{}Instantiation".format(op_class_under_test.__name__), - OperationInstantiationTests, - ) - setattr( - test_module, "Test{}Complete".format(op_class_under_test.__name__), OperationCompleteTests - ) - setattr( - test_module, - "Test{}AddCallback".format(op_class_under_test.__name__), - OperationAddCallbackTests, - ) - setattr( - test_module, - 
"Test{}HaltCompletion".format(op_class_under_test.__name__), - OperationHaltCompletionTests, - ) - setattr( - test_module, - "Test{}SpawnWorkerOp".format(op_class_under_test.__name__), - OperationSpawnWorkerOpTests, - ) diff --git a/tests/unit/common/pipeline/pipeline_stage_test.py b/tests/unit/common/pipeline/pipeline_stage_test.py deleted file mode 100644 index ec732265a..000000000 --- a/tests/unit/common/pipeline/pipeline_stage_test.py +++ /dev/null @@ -1,196 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import logging -import pytest -from tests.unit.common.pipeline.helpers import StageRunOpTestBase, StageHandlePipelineEventTestBase -from azure.iot.device.common.pipeline.pipeline_stages_base import PipelineStage -from azure.iot.device.common.pipeline import pipeline_exceptions, pipeline_events_base - -logging.basicConfig(level=logging.DEBUG) - - -def add_base_pipeline_stage_tests( - test_module, - stage_class_under_test, - stage_test_config_class, - extended_stage_instantiation_test_class=None, -): - class StageTestConfig(stage_test_config_class): - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - stage.next = mocker.MagicMock() - stage.previous = mocker.MagicMock() - mocker.spy(stage, "send_op_down") - mocker.spy(stage, "send_event_up") - mocker.spy(stage, "report_background_exception") - return stage - - ####################### - # INSTANTIATION TESTS # - ####################### - - @pytest.mark.describe("{} -- Instantiation".format(stage_class_under_test.__name__)) - class StageBaseInstantiationTests(StageTestConfig): - @pytest.mark.it("Initializes 'name' attribute as the classname") - def test_name(self, cls_type, init_kwargs): - 
stage = cls_type(**init_kwargs) - assert stage.name == stage.__class__.__name__ - - @pytest.mark.it("Initializes 'next' attribute as None") - def test_next(self, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - assert stage.next is None - - @pytest.mark.it("Initializes 'previous' attribute as None") - def test_previous(self, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - assert stage.previous is None - - @pytest.mark.it("Initializes 'nucleus' attribute as None") - def test_pipeline_nucleus(self, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - assert stage.nucleus is None - - if extended_stage_instantiation_test_class: - - class StageInstantiationTests( - extended_stage_instantiation_test_class, StageBaseInstantiationTests - ): - pass - - else: - - class StageInstantiationTests(StageBaseInstantiationTests): - pass - - setattr( - test_module, - "Test{}Instantiation".format(stage_class_under_test.__name__), - StageInstantiationTests, - ) - - ############## - # FLOW TESTS # - ############## - - @pytest.mark.describe("{} - .send_op_down()".format(stage_class_under_test.__name__)) - class StageSendOpDownTests(StageTestConfig): - @pytest.mark.it("Passes the op to the next stage's .run_op() method") - def test_passes_op_to_next_stage(self, mocker, stage, arbitrary_op): - stage.send_op_down(arbitrary_op) - assert stage.next.run_op.call_count == 1 - assert stage.next.run_op.call_args == mocker.call(arbitrary_op) - - @pytest.mark.it("Raises a PipelineRuntimeError if there is no next stage") - def test_fails_op_when_no_next_stage(self, mocker, stage, arbitrary_op): - stage.next = None - - with pytest.raises(pipeline_exceptions.PipelineRuntimeError): - stage.send_op_down(arbitrary_op) - - @pytest.mark.describe("{} - .send_event_up()".format(stage_class_under_test.__name__)) - class StageSendEventUpTests(StageTestConfig): - @pytest.mark.it( - "Passes the event up to the previous stage's .handle_pipeline_event() method" - ) - def 
test_calls_handle_pipeline_event(self, stage, arbitrary_event, mocker): - stage.send_event_up(arbitrary_event) - assert stage.previous.handle_pipeline_event.call_count == 1 - assert stage.previous.handle_pipeline_event.call_args == mocker.call(arbitrary_event) - - @pytest.mark.it("Raises a PipelineRuntimeError if there is no previous stage") - def test_no_previous_stage(self, stage, arbitrary_event): - stage.previous = None - with pytest.raises(pipeline_exceptions.PipelineRuntimeError): - stage.send_event_up(arbitrary_event) - - setattr( - test_module, - "Test{}SendOpDown".format(stage_class_under_test.__name__), - StageSendOpDownTests, - ) - setattr( - test_module, - "Test{}SendEventUp".format(stage_class_under_test.__name__), - StageSendEventUpTests, - ) - - ############################################# - # RUN OP / HANDLE_PIPELINE_EVENT BASE TESTS # - ############################################# - - # These tests are only run if the Stage in question has NOT overridden the PipelineStage base - # implementations of ._run_op() and/or ._handle_pipeline_event() - - if stage_class_under_test._run_op is PipelineStage._run_op: - - @pytest.mark.describe("{} - .run_op()".format(stage_class_under_test.__name__)) - class StageRunOpUnhandledOp(StageTestConfig, StageRunOpTestBase): - @pytest.fixture - def op(self, arbitrary_op): - return arbitrary_op - - @pytest.mark.it("Sends the operation down the pipeline") - def test_passes_down(self, mocker, stage, op): - stage.run_op(op) - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - setattr( - test_module, - "Test{}RunOpUnhandledOp".format(stage_class_under_test.__name__), - StageRunOpUnhandledOp, - ) - - if stage_class_under_test._handle_pipeline_event is PipelineStage._handle_pipeline_event: - - @pytest.mark.describe( - "{} - .handle_pipeline_event()".format(stage_class_under_test.__name__) - ) - class StageHandlePipelineEventUnhandledEvent( - StageTestConfig, 
StageHandlePipelineEventTestBase - ): - @pytest.fixture - def event(self, arbitrary_event): - return arbitrary_event - - @pytest.mark.it("Sends the event up the pipeline") - def test_passes_up(self, mocker, stage, event): - stage.handle_pipeline_event(event) - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args == mocker.call(event) - - setattr( - test_module, - "Test{}HandlePipelineEventUnhandledEvent".format(stage_class_under_test.__name__), - StageHandlePipelineEventUnhandledEvent, - ) - - ############### - # OTHER TESTS # - ############### - - @pytest.mark.describe( - "{} - .report_background_exception()".format(stage_class_under_test.__name__) - ) - class StageRaiseBackgroundExceptionTests(StageTestConfig): - @pytest.mark.it( - "Sends a BackgroundExceptionEvent up the pipeline with the provided exception set on it" - ) - def test_new_event_sent_up(self, mocker, stage, arbitrary_exception): - stage.report_background_exception(arbitrary_exception) - - assert stage.send_event_up.call_count == 1 - event = stage.send_event_up.call_args[0][0] - assert isinstance(event, pipeline_events_base.BackgroundExceptionEvent) - assert event.e is arbitrary_exception - - setattr( - test_module, - "Test{}RaiseBackgroundException".format(stage_class_under_test.__name__), - StageRaiseBackgroundExceptionTests, - ) diff --git a/tests/unit/common/pipeline/test_pipeline_events_base.py b/tests/unit/common/pipeline/test_pipeline_events_base.py deleted file mode 100644 index 36252a47a..000000000 --- a/tests/unit/common/pipeline/test_pipeline_events_base.py +++ /dev/null @@ -1,57 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import sys -import pytest -import logging -from azure.iot.device.common.pipeline import pipeline_events_base -from tests.unit.common.pipeline import pipeline_event_test - -logging.basicConfig(level=logging.DEBUG) -this_module = sys.modules[__name__] - - -@pytest.mark.describe("PipelineEvent") -class TestPipelineOperation(object): - @pytest.mark.it("Can't be instantiated") - def test_instantiate(self): - with pytest.raises(TypeError): - pipeline_events_base.PipelineEvent() - - -pipeline_event_test.add_event_test( - cls=pipeline_events_base.ResponseEvent, - module=this_module, - positional_arguments=["request_id", "status_code", "response_body"], - keyword_arguments={}, -) - -pipeline_event_test.add_event_test( - cls=pipeline_events_base.ConnectedEvent, - module=this_module, - positional_arguments=[], - keyword_arguments={}, -) - -pipeline_event_test.add_event_test( - cls=pipeline_events_base.DisconnectedEvent, - module=this_module, - positional_arguments=[], - keyword_arguments={}, -) - -pipeline_event_test.add_event_test( - cls=pipeline_events_base.NewSasTokenRequiredEvent, - module=this_module, - positional_arguments=[], - keyword_arguments={}, -) - -pipeline_event_test.add_event_test( - cls=pipeline_events_base.BackgroundExceptionEvent, - module=this_module, - positional_arguments=["e"], - keyword_arguments={}, -) diff --git a/tests/unit/common/pipeline/test_pipeline_events_mqtt.py b/tests/unit/common/pipeline/test_pipeline_events_mqtt.py deleted file mode 100644 index 24b86fb3d..000000000 --- a/tests/unit/common/pipeline/test_pipeline_events_mqtt.py +++ /dev/null @@ -1,20 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import sys -import logging -from azure.iot.device.common.pipeline import pipeline_events_mqtt -from tests.unit.common.pipeline import pipeline_event_test - -logging.basicConfig(level=logging.DEBUG) - -this_module = sys.modules[__name__] - -pipeline_event_test.add_event_test( - cls=pipeline_events_mqtt.IncomingMQTTMessageEvent, - module=this_module, - positional_arguments=["topic", "payload"], - keyword_arguments={}, -) diff --git a/tests/unit/common/pipeline/test_pipeline_nucleus.py b/tests/unit/common/pipeline/test_pipeline_nucleus.py deleted file mode 100644 index 3d89c3a34..000000000 --- a/tests/unit/common/pipeline/test_pipeline_nucleus.py +++ /dev/null @@ -1,62 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import pytest -import logging -from azure.iot.device.common.pipeline.pipeline_nucleus import PipelineNucleus, ConnectionState - - -logging.basicConfig(level=logging.DEBUG) - - -@pytest.mark.describe("PipelineNucleus - Instantiation") -class TestPipelineNucleusInstantiation(object): - @pytest.fixture - def pipeline_config(self, mocker): - return mocker.MagicMock() - - @pytest.mark.it( - "Instantiates with the 'pipeline_configuration' attribute set the the value of the provided 'pipeline_configuration' parameter" - ) - def test_pipeline_config(self, pipeline_config): - nucleus = PipelineNucleus(pipeline_configuration=pipeline_config) - assert nucleus.pipeline_configuration is pipeline_config - - @pytest.mark.it("Instantiates with the 'connection_state' attribute set to DISCONNECTED") - def test_connected(self, pipeline_config): - nucleus = PipelineNucleus(pipeline_config) - assert nucleus.connection_state is 
ConnectionState.DISCONNECTED - - -@pytest.mark.describe("PipelineNucleus - PROPERTY .connected") -class TestPipelineNucleusPROPERTYConnected(object): - @pytest.fixture - def nucleus(self, mocker): - pl_cfg = mocker.MagicMock() - return PipelineNucleus(pl_cfg) - - @pytest.mark.it("Is a read-only property") - def test_read_only(self, nucleus): - with pytest.raises(AttributeError): - nucleus.connected = False - - @pytest.mark.it("Returns True if the '.connection_state' attribute is CONNECTED") - def test_connected(self, nucleus): - nucleus.connection_state = ConnectionState.CONNECTED - assert nucleus.connected - - @pytest.mark.it("Returns False if the '.connection_state' attribute has any other value") - @pytest.mark.parametrize( - "state", - [ - ConnectionState.DISCONNECTED, - ConnectionState.CONNECTING, - ConnectionState.DISCONNECTING, - ConnectionState.REAUTHORIZING, - ], - ) - def test_not_connected(self, nucleus, state): - nucleus.connection_state = state - assert not nucleus.connected diff --git a/tests/unit/common/pipeline/test_pipeline_ops_base.py b/tests/unit/common/pipeline/test_pipeline_ops_base.py deleted file mode 100644 index 1e35351fa..000000000 --- a/tests/unit/common/pipeline/test_pipeline_ops_base.py +++ /dev/null @@ -1,292 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import sys -import pytest -import logging -from azure.iot.device.common.pipeline import pipeline_ops_base -from tests.unit.common.pipeline import pipeline_ops_test - -this_module = sys.modules[__name__] -logging.basicConfig(level=logging.DEBUG) -pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") - - -class InitializePipelineOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_base.InitializePipelineOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"callback": mocker.MagicMock()} - return kwargs - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_base.InitializePipelineOperation, - op_test_config_class=InitializePipelineOperationTestConfig, -) - - -class ShutdownPipelineOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_base.ShutdownPipelineOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"callback": mocker.MagicMock()} - return kwargs - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_base.ShutdownPipelineOperation, - op_test_config_class=ShutdownPipelineOperationTestConfig, -) - - -class ConnectOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_base.ConnectOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"callback": mocker.MagicMock()} - return kwargs - - -class ConnectOperationInstantiationTests(ConnectOperationTestConfig): - @pytest.mark.it("Initializes 'watchdog_timer' attribute to 'None'") - def test_retry_timer(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.watchdog_timer is None - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_base.ConnectOperation, - 
op_test_config_class=ConnectOperationTestConfig, - extended_op_instantiation_test_class=ConnectOperationInstantiationTests, -) - - -class DisconnectOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_base.DisconnectOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"callback": mocker.MagicMock()} - return kwargs - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_base.DisconnectOperation, - op_test_config_class=DisconnectOperationTestConfig, -) - - -class ReauthorizeConnectionOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_base.ReauthorizeConnectionOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"callback": mocker.MagicMock()} - return kwargs - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_base.ReauthorizeConnectionOperation, - op_test_config_class=ReauthorizeConnectionOperationTestConfig, -) - - -class EnableFeatureOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_base.EnableFeatureOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"feature_name": "some_feature", "callback": mocker.MagicMock()} - return kwargs - - -class EnableFeatureInstantiationTests(EnableFeatureOperationTestConfig): - @pytest.mark.it( - "Initializes 'feature_name' attribute with the provided 'feature_name' parameter" - ) - def test_feature_name(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.feature_name == init_kwargs["feature_name"] - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_base.EnableFeatureOperation, - op_test_config_class=EnableFeatureOperationTestConfig, - extended_op_instantiation_test_class=EnableFeatureInstantiationTests, -) - - -class DisableFeatureOperationTestConfig(object): - 
@pytest.fixture - def cls_type(self): - return pipeline_ops_base.DisableFeatureOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"feature_name": "some_feature", "callback": mocker.MagicMock()} - return kwargs - - -class DisableFeatureInstantiationTests(DisableFeatureOperationTestConfig): - @pytest.mark.it( - "Initializes 'feature_name' attribute with the provided 'feature_name' parameter" - ) - def test_feature_name(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.feature_name == init_kwargs["feature_name"] - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_base.DisableFeatureOperation, - op_test_config_class=DisableFeatureOperationTestConfig, - extended_op_instantiation_test_class=DisableFeatureInstantiationTests, -) - - -class RequestAndResponseOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_base.RequestAndResponseOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = { - "request_type": "some_request_type", - "method": "SOME_METHOD", - "resource_location": "some/resource/location", - "request_body": "some_request_body", - "callback": mocker.MagicMock(), - } - return kwargs - - -class RequestAndResponseOperationInstantiationTests(RequestAndResponseOperationTestConfig): - @pytest.mark.it( - "Initializes 'request_type' attribute with the provided 'request_type' parameter" - ) - def test_request_type(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.request_type == init_kwargs["request_type"] - - @pytest.mark.it("Initializes 'method' attribute with the provided 'method' parameter") - def test_method_type(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.method == init_kwargs["method"] - - @pytest.mark.it( - "Initializes 'resource_location' attribute with the provided 'resource_location' parameter" - ) - def test_resource_location(self, cls_type, 
init_kwargs): - op = cls_type(**init_kwargs) - assert op.resource_location == init_kwargs["resource_location"] - - @pytest.mark.it( - "Initializes 'request_body' attribute with the provided 'request_body' parameter" - ) - def test_request_body(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.request_body == init_kwargs["request_body"] - - @pytest.mark.it("Initializes 'status_code' attribute to None") - def test_status_code(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.status_code is None - - @pytest.mark.it("Initializes 'response_body' attribute to None") - def test_response_body(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.response_body is None - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_base.RequestAndResponseOperation, - op_test_config_class=RequestAndResponseOperationTestConfig, - extended_op_instantiation_test_class=RequestAndResponseOperationInstantiationTests, -) - - -class RequestOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_base.RequestOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = { - "method": "SOME_METHOD", - "resource_location": "some/resource/location", - "request_type": "some_request_type", - "request_body": "some_request_body", - "request_id": "some_request_id", - "callback": mocker.MagicMock(), - } - return kwargs - - -class RequestOperationInstantiationTests(RequestOperationTestConfig): - @pytest.mark.it("Initializes the 'method' attribute with the provided 'method' parameter") - def test_method(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.method == init_kwargs["method"] - - @pytest.mark.it( - "Initializes the 'resource_location' attribute with the provided 'resource_location' parameter" - ) - def test_resource_location(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.resource_location 
== init_kwargs["resource_location"] - - @pytest.mark.it( - "Initializes the 'request_type' attribute with the provided 'request_type' parameter" - ) - def test_request_type(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.request_type == init_kwargs["request_type"] - - @pytest.mark.it( - "Initializes the 'request_body' attribute with the provided 'request_body' parameter" - ) - def test_request_body(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.request_body == init_kwargs["request_body"] - - @pytest.mark.it( - "Initializes the 'request_id' attribute with the provided 'request_id' parameter" - ) - def test_request_id(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.request_id == init_kwargs["request_id"] - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_base.RequestOperation, - op_test_config_class=RequestOperationTestConfig, - extended_op_instantiation_test_class=RequestOperationInstantiationTests, -) diff --git a/tests/unit/common/pipeline/test_pipeline_ops_http.py b/tests/unit/common/pipeline/test_pipeline_ops_http.py deleted file mode 100644 index 357cb163d..000000000 --- a/tests/unit/common/pipeline/test_pipeline_ops_http.py +++ /dev/null @@ -1,84 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import pytest -import sys -import logging -from azure.iot.device.common.pipeline import pipeline_ops_http -from tests.unit.common.pipeline import pipeline_ops_test - -logging.basicConfig(level=logging.DEBUG) -this_module = sys.modules[__name__] -pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") - - -class HTTPRequestAndResponseOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_http.HTTPRequestAndResponseOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = { - "method": "some_topic", - "path": "some_path", - "headers": {"some_key": "some_value"}, - "body": "some_body", - "query_params": "some_query_params", - "callback": mocker.MagicMock(), - } - return kwargs - - -class HTTPRequestAndResponseOperationInstantiationTests(HTTPRequestAndResponseOperationTestConfig): - @pytest.mark.it("Initializes 'method' attribute with the provided 'method' parameter") - def test_method(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.method == init_kwargs["method"] - - @pytest.mark.it("Initializes 'path' attribute with the provided 'path' parameter") - def test_path(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.path == init_kwargs["path"] - - @pytest.mark.it("Initializes 'headers' attribute with the provided 'headers' parameter") - def test_headers(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.headers == init_kwargs["headers"] - - @pytest.mark.it("Initializes 'body' attribute with the provided 'body' parameter") - def test_body(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.body == init_kwargs["body"] - - @pytest.mark.it( - "Initializes 'query_params' attribute with the provided 'query_params' parameter" - ) - def test_query_params(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.query_params == 
init_kwargs["query_params"] - - @pytest.mark.it("Initializes 'status_code' attribute as None") - def test_status_code(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.status_code is None - - @pytest.mark.it("Initializes 'response_body' attribute as None") - def test_response_body(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.response_body is None - - @pytest.mark.it("Initializes 'reason' attribute as None") - def test_reason(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.reason is None - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_http.HTTPRequestAndResponseOperation, - op_test_config_class=HTTPRequestAndResponseOperationTestConfig, - extended_op_instantiation_test_class=HTTPRequestAndResponseOperationInstantiationTests, -) diff --git a/tests/unit/common/pipeline/test_pipeline_ops_mqtt.py b/tests/unit/common/pipeline/test_pipeline_ops_mqtt.py deleted file mode 100644 index e138ce93d..000000000 --- a/tests/unit/common/pipeline/test_pipeline_ops_mqtt.py +++ /dev/null @@ -1,132 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import pytest -import sys -import logging -from azure.iot.device.common.pipeline import pipeline_ops_mqtt -from tests.unit.common.pipeline import pipeline_ops_test - -logging.basicConfig(level=logging.DEBUG) -this_module = sys.modules[__name__] -pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") - - -class MQTTPublishOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_mqtt.MQTTPublishOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"topic": "some_topic", "payload": "some_payload", "callback": mocker.MagicMock()} - return kwargs - - -class MQTTPublishOperationInstantiationTests(MQTTPublishOperationTestConfig): - @pytest.mark.it("Initializes 'topic' attribute with the provided 'topic' parameter") - def test_topic(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.topic == init_kwargs["topic"] - - @pytest.mark.it("Initializes 'payload' attribute with the provided 'payload' parameter") - def test_payload(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.payload == init_kwargs["payload"] - - @pytest.mark.it("Initializes 'needs_connection' attribute as True") - def test_needs_connection(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.needs_connection is True - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_mqtt.MQTTPublishOperation, - op_test_config_class=MQTTPublishOperationTestConfig, - extended_op_instantiation_test_class=MQTTPublishOperationInstantiationTests, -) - - -class MQTTSubscribeOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_mqtt.MQTTSubscribeOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"topic": "some_topic", "callback": mocker.MagicMock()} - return kwargs - - -class 
MQTTSubscribeOperationInstantiationTests(MQTTSubscribeOperationTestConfig): - @pytest.mark.it("Initializes 'topic' attribute with the provided 'topic' parameter") - def test_topic(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.topic == init_kwargs["topic"] - - @pytest.mark.it("Initializes 'needs_connection' attribute as True") - def test_needs_connection(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.needs_connection is True - - @pytest.mark.it("Initializes 'timeout_timer' attribute as None") - def test_timeout_timer(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.timeout_timer is None - - @pytest.mark.it("Initializes 'retry_timer' attribute as None") - def test_retry_timer(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.retry_timer is None - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_mqtt.MQTTSubscribeOperation, - op_test_config_class=MQTTSubscribeOperationTestConfig, - extended_op_instantiation_test_class=MQTTSubscribeOperationInstantiationTests, -) - - -class MQTTUnsubscribeOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_mqtt.MQTTUnsubscribeOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"topic": "some_topic", "callback": mocker.MagicMock()} - return kwargs - - -class MQTTUnsubscribeOperationInstantiationTests(MQTTUnsubscribeOperationTestConfig): - @pytest.mark.it("Initializes 'topic' attribute with the provided 'topic' parameter") - def test_topic(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.topic == init_kwargs["topic"] - - @pytest.mark.it("Initializes 'needs_connection' attribute as True") - def test_needs_connection(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.needs_connection is True - - @pytest.mark.it("Initializes 'timeout_timer' attribute as None") - def 
test_timeout_timer(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.timeout_timer is None - - @pytest.mark.it("Initializes 'retry_timer' attribute as None") - def test_retry_timer(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.retry_timer is None - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_mqtt.MQTTUnsubscribeOperation, - op_test_config_class=MQTTUnsubscribeOperationTestConfig, - extended_op_instantiation_test_class=MQTTUnsubscribeOperationInstantiationTests, -) diff --git a/tests/unit/common/pipeline/test_pipeline_stages_base.py b/tests/unit/common/pipeline/test_pipeline_stages_base.py deleted file mode 100644 index a5374bab3..000000000 --- a/tests/unit/common/pipeline/test_pipeline_stages_base.py +++ /dev/null @@ -1,3885 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import logging -import copy -import time -import pytest -import sys -import threading -import random -import uuid -import queue -from azure.iot.device.common import transport_exceptions, alarm -from azure.iot.device.common.auth import sastoken as st -from azure.iot.device.common.pipeline import ( - pipeline_stages_base, - pipeline_ops_base, - pipeline_ops_mqtt, - pipeline_events_base, - pipeline_exceptions, -) - -# I normally try to keep my imports in tests at module level, but it's just too unwieldy w/ ConnectionState -from azure.iot.device.common.pipeline.pipeline_nucleus import ConnectionState -from .helpers import StageRunOpTestBase, StageHandlePipelineEventTestBase -from .fixtures import ArbitraryOperation -from tests.unit.common.pipeline import pipeline_stage_test - - -this_module = sys.modules[__name__] -logging.basicConfig(level=logging.DEBUG) -pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") - - -fake_signed_data = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=" -fake_uri = "some/resource/location" -fake_current_time = 10000000000 -fake_expiry = 10000003600 - - -################### -# COMMON FIXTURES # -################### -@pytest.fixture -def mock_timer(mocker): - return mocker.patch.object(threading, "Timer") - - -@pytest.fixture -def mock_alarm(mocker): - return mocker.patch.object(alarm, "Alarm") - - -@pytest.fixture(autouse=True) -def mock_time(mocker): - # Need to ALWAYS mock current time - time_mock = mocker.patch.object(time, "time") - time_mock.return_value = fake_current_time - - -# Not a fixture, but useful for sharing -def fake_callback(*args, **kwargs): - pass - - -####################### -# PIPELINE ROOT STAGE # -####################### - - -class PipelineRootStageTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_base.PipelineRootStage - - @pytest.fixture - def init_kwargs(self, nucleus): - return {"nucleus": nucleus} 
- - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - -class PipelineRootStageInstantiationTests(PipelineRootStageTestConfig): - @pytest.mark.it("Initializes 'on_pipeline_event_handler' as None") - def test_on_pipeline_event_handler(self, init_kwargs): - stage = pipeline_stages_base.PipelineRootStage(**init_kwargs) - assert stage.on_pipeline_event_handler is None - - @pytest.mark.it("Initializes 'on_connected_handler' as None") - def test_on_connected_handler(self, init_kwargs): - stage = pipeline_stages_base.PipelineRootStage(**init_kwargs) - assert stage.on_connected_handler is None - - @pytest.mark.it("Initializes 'on_disconnected_handler' as None") - def test_on_disconnected_handler(self, init_kwargs): - stage = pipeline_stages_base.PipelineRootStage(**init_kwargs) - assert stage.on_disconnected_handler is None - - @pytest.mark.it("Initializes 'on_new_sastoken_required_handler' as None") - def test_on_new_sastoken_required_handler(self, init_kwargs): - stage = pipeline_stages_base.PipelineRootStage(**init_kwargs) - assert stage.on_new_sastoken_required_handler is None - - @pytest.mark.it("Initializes 'on_background_exception_handler' as None") - def test_on_background_exception_handler(self, init_kwargs): - stage = pipeline_stages_base.PipelineRootStage(**init_kwargs) - assert stage.on_background_exception_handler is None - - @pytest.mark.it("Initializes 'nucleus' with the provided 'nucleus' parameter") - def test_pipeline_nucleus(self, init_kwargs): - stage = pipeline_stages_base.PipelineRootStage(**init_kwargs) - assert stage.nucleus is init_kwargs["nucleus"] - - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_base.PipelineRootStage, - stage_test_config_class=PipelineRootStageTestConfig, 
- extended_stage_instantiation_test_class=PipelineRootStageInstantiationTests, -) - - -@pytest.mark.describe("PipelineRootStage - .append_stage()") -class TestPipelineRootStageAppendStage(PipelineRootStageTestConfig): - @pytest.mark.it("Appends the provided stage to the tail of the pipeline") - @pytest.mark.parametrize( - "pipeline_len", - [ - pytest.param(1, id="Pipeline Length: 1"), - pytest.param(2, id="Pipeline Length: 2"), - pytest.param(3, id="Pipeline Length: 3"), - pytest.param(10, id="Pipeline Length: 10"), - pytest.param(random.randint(4, 99), id="Randomly chosen Pipeline Length"), - ], - ) - def test_appends_new_stage(self, stage, pipeline_len): - class ArbitraryStage(pipeline_stages_base.PipelineStage): - pass - - assert stage.next is None - assert stage.previous is None - prev_tail = stage - root = stage - for _ in range(0, pipeline_len): - new_stage = ArbitraryStage() - stage.append_stage(new_stage) - assert prev_tail.next is new_stage - assert new_stage.previous is prev_tail - assert new_stage.nucleus is root.nucleus - prev_tail = new_stage - - -# NOTE 1: Because the Root stage overrides the parent implementation, we must test it here -# (even though it's the same test). -# NOTE 2: Currently this implementation does some other things with threads, but we do not -# currently have a thread testing strategy, so it is untested for now. 
-@pytest.mark.describe("PipelineRootStage - .run_op()") -class TestPipelineRootStageRunOp(PipelineRootStageTestConfig): - @pytest.fixture - def op(self, arbitrary_op): - return arbitrary_op - - @pytest.mark.it("Sends the operation down") - def test_sends_op_down(self, mocker, stage, op): - stage.run_op(op) - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - -@pytest.mark.describe("PipelineRootStage - .handle_pipeline_event() -- Called with ConnectedEvent") -class TestPipelineRootStageHandlePipelineEventWithConnectedEvent( - PipelineRootStageTestConfig, StageHandlePipelineEventTestBase -): - @pytest.fixture - def event(self): - return pipeline_events_base.ConnectedEvent() - - @pytest.mark.it("Invokes the 'on_connected_handler' handler function, if set") - def test_invoke_handler(self, mocker, stage, event): - mock_handler = mocker.MagicMock() - stage.on_connected_handler = mock_handler - stage.handle_pipeline_event(event) - time.sleep(0.1) # Needs a brief sleep so thread can switch - assert mock_handler.call_count == 1 - assert mock_handler.call_args == mocker.call() - - -@pytest.mark.describe( - "PipelineRootStage - .handle_pipeline_event() -- Called with DisconnectedEvent" -) -class TestPipelineRootStageHandlePipelineEventWithDisconnectedEvent( - PipelineRootStageTestConfig, StageHandlePipelineEventTestBase -): - @pytest.fixture - def event(self): - return pipeline_events_base.DisconnectedEvent() - - @pytest.mark.it("Invokes the 'on_disconnected_handler' handler function, if set") - def test_invoke_handler(self, mocker, stage, event): - mock_handler = mocker.MagicMock() - stage.on_disconnected_handler = mock_handler - stage.handle_pipeline_event(event) - time.sleep(0.1) # Needs a brief sleep so thread can switch - assert mock_handler.call_count == 1 - assert mock_handler.call_args == mocker.call() - - -@pytest.mark.describe( - "PipelineRootStage - .handle_pipeline_event() -- Called with NewSasTokenRequiredEvent" 
-) -class TestPipelineRootStageHandlePipelineEventWithNewSasTokenRequiredEvent( - PipelineRootStageTestConfig, StageHandlePipelineEventTestBase -): - @pytest.fixture - def event(self): - return pipeline_events_base.NewSasTokenRequiredEvent() - - @pytest.mark.it("Invokes the 'on_new_sastoken_required_handler' handler function, if set") - def test_invoke_handler(self, mocker, stage, event): - mock_handler = mocker.MagicMock() - stage.on_new_sastoken_required_handler = mock_handler - stage.handle_pipeline_event(event) - time.sleep(0.1) # Needs a brief sleep so thread can switch - assert mock_handler.call_count == 1 - assert mock_handler.call_args == mocker.call() - - -@pytest.mark.describe( - "PipelineRootStage - .handle_pipeline_event() -- Called with BackgroundExceptionEvent" -) -class TestPipelineRootStageHandlePipelineEventWithBackgroundExceptionEvent( - PipelineRootStageTestConfig, StageHandlePipelineEventTestBase -): - @pytest.fixture - def event(self, arbitrary_exception): - return pipeline_events_base.BackgroundExceptionEvent(arbitrary_exception) - - @pytest.mark.it( - "Invokes the 'on_background_exception_handler' handler function, passing the exception object, if set" - ) - def test_invoke_handler(self, mocker, stage, event): - mock_handler = mocker.MagicMock() - stage.on_background_exception_handler = mock_handler - stage.handle_pipeline_event(event) - time.sleep(0.1) # Needs a brief sleep so thread can switch - assert mock_handler.call_count == 1 - assert mock_handler.call_args == mocker.call(event.e) - - -@pytest.mark.describe( - "PipelineRootStage - .handle_pipeline_event() -- Called with an arbitrary other event" -) -class TestPipelineRootStageHandlePipelineEventWithArbitraryEvent( - PipelineRootStageTestConfig, StageHandlePipelineEventTestBase -): - @pytest.fixture - def event(self, arbitrary_event): - return arbitrary_event - - @pytest.mark.it("Invokes the 'on_pipeline_event_handler' handler function, if set") - def test_invoke_handler(self, mocker, 
stage, event): - mock_handler = mocker.MagicMock() - stage.on_pipeline_event_handler = mock_handler - stage.handle_pipeline_event(event) - time.sleep(0.1) # Needs a brief sleep so thread can switch - assert mock_handler.call_count == 1 - assert mock_handler.call_args == mocker.call(event) - - -################### -# SAS TOKEN STAGE # -################### - - -class SasTokenStageTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_base.SasTokenStage - - @pytest.fixture - def init_kwargs(self, mocker): - return {} - - @pytest.fixture - def stage(self, mocker, cls_type, nucleus, sastoken, init_kwargs): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.nucleus.pipeline_configuration.sastoken = sastoken - stage.nucleus.pipeline_configuration.connection_retry_interval = 1234 - # Mock flow methods - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - stage.report_background_exception = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - -class SasTokenStageInstantiationTests(SasTokenStageTestConfig): - @pytest.mark.it("Initializes with the token update alarm set to 'None'") - def test_token_update_timer(self, init_kwargs): - stage = pipeline_stages_base.SasTokenStage(**init_kwargs) - assert stage._token_update_alarm is None - - @pytest.mark.it("Initializes with the reauth retry timer set to 'None'") - def test_reauth_retry_timer(self, init_kwargs): - stage = pipeline_stages_base.SasTokenStage(**init_kwargs) - assert stage._reauth_retry_timer is None - - @pytest.mark.it("Uses 120 seconds as the Update Margin by default") - def test_update_margin(self, init_kwargs): - # NOTE: currently, update margin isn't set as an instance attribute really, it just uses - # a constant defined on the class in all cases. 
Eventually this logic may be expanded to - # be more dynamic, and this test will need to change - stage = pipeline_stages_base.SasTokenStage(**init_kwargs) - assert stage.DEFAULT_TOKEN_UPDATE_MARGIN == 120 - - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_base.SasTokenStage, - stage_test_config_class=SasTokenStageTestConfig, - extended_stage_instantiation_test_class=SasTokenStageInstantiationTests, -) - - -@pytest.mark.describe( - "SasTokenStage - .run_op() -- Called with InitializePipelineOperation (Pipeline configured for SAS authentication)" -) -class TestSasTokenStageRunOpWithInitializePipelineOpSasTokenConfig( - SasTokenStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) - - @pytest.fixture(params=["Renewable SAS Authentication", "Non-renewable SAS Authentication"]) - def sastoken(self, mocker, request): - if request.param == "Renewable SAS Authentication": - mock_signing_mechanism = mocker.MagicMock() - mock_signing_mechanism.sign.return_value = fake_signed_data - sastoken = st.RenewableSasToken(uri=fake_uri, signing_mechanism=mock_signing_mechanism) - sastoken.refresh = mocker.MagicMock() - else: - token_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( - resource=fake_uri, signature=fake_signed_data, expiry=fake_expiry - ) - sastoken = st.NonRenewableSasToken(token_str) - return sastoken - - @pytest.mark.it("Cancels any existing token update alarm that may have been set") - def test_cancels_existing_alarm(self, mocker, mock_alarm, stage, op): - stage._token_update_alarm = mock_alarm - - stage.run_op(op) - - assert mock_alarm.cancel.call_count == 1 - assert mock_alarm.cancel.call_args == mocker.call() - - @pytest.mark.it("Resets the token update alarm to None until a new one is set") - # Edge case, since unless something goes wrong, the alarm 
WILL be set, and it's like - # it was never set to None. - def test_alarm_set_to_none_in_intermediate( - self, mocker, stage, op, mock_alarm, arbitrary_exception - ): - # Set an existing alarm - stage._token_update_alarm = mocker.MagicMock() - - # Set an error side effect on the alarm creation, so when a new alarm is created, - # we have an unhandled error causing op failure and early exit - mock_alarm.side_effect = arbitrary_exception - - stage.run_op(op) - - assert op.complete - assert op.error is arbitrary_exception - assert stage._token_update_alarm is None - - @pytest.mark.it( - "Starts a background update alarm that will trigger 'Update Margin' number of seconds prior to SasToken expiration" - ) - def test_sets_alarm(self, mocker, stage, op, mock_alarm): - expected_alarm_time = ( - stage.nucleus.pipeline_configuration.sastoken.expiry_time - - pipeline_stages_base.SasTokenStage.DEFAULT_TOKEN_UPDATE_MARGIN - ) - - stage.run_op(op) - - assert mock_alarm.call_count == 1 - assert mock_alarm.call_args[0][0] == expected_alarm_time - assert mock_alarm.return_value.daemon is True - assert mock_alarm.return_value.start.call_count == 1 - assert mock_alarm.return_value.start.call_args == mocker.call() - - @pytest.mark.it( - "Starts a background update alarm that will instead trigger after MAX_TIMEOUT seconds if the SasToken expiration time (less the Update Margin) is more than MAX_TIMEOUT seconds in the future" - ) - def test_sets_alarm_long_expiration(self, mocker, stage, op, mock_alarm): - token = stage.nucleus.pipeline_configuration.sastoken - new_expiry = token.expiry_time + threading.TIMEOUT_MAX - if isinstance(token, st.RenewableSasToken): - token._expiry_time = new_expiry - else: - token._token_info["se"] = new_expiry - # NOTE: time.time is implicitly mocked to return a constant test value here - expected_alarm_time = time.time() + threading.TIMEOUT_MAX - assert expected_alarm_time < token.expiry_time - - stage.run_op(op) - - assert mock_alarm.call_count == 1 - 
assert mock_alarm.call_args[0][0] == expected_alarm_time - assert mock_alarm.return_value.daemon is True - assert mock_alarm.return_value.start.call_count == 1 - assert mock_alarm.return_value.start.call_args == mocker.call() - - -@pytest.mark.describe( - "SasTokenStage - .run_op() -- Called with InitializePipelineOperation (Pipeline not configured for SAS authentication)" -) -class TestSasTokenStageRunOpWithInitializePipelineOpNoSasTokenConfig( - SasTokenStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) - - @pytest.fixture - def sastoken(self): - return None - - @pytest.mark.it("Sends the operation down, WITHOUT setting a update alarm") - def test_sends_op_down_no_alarm(self, mocker, stage, mock_alarm, op): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - assert stage._token_update_alarm is None - assert mock_alarm.call_count == 0 - - -@pytest.mark.describe( - "SasTokenStage - .run_op() -- Called with ReauthorizeConnectionOperation (Pipeline configured for SAS authentication)" -) -class TestSasTokenStageRunOpWithReauthorizeConnectionOperationPipelineOpSasTokenConfig( - SasTokenStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.ReauthorizeConnectionOperation(callback=mocker.MagicMock()) - - # NOTE: We test both renewable and non-renewable here for safety, but in practice, this will - # only ever be for non-renewable tokens due to how the client forms the pipeline. A - # ReauthorizeConnectionOperation that appears this high in the pipeline could only be created - # in the case of non-renewable SAS flow. 
- @pytest.fixture(params=["Renewable SAS Authentication", "Non-renewable SAS Authentication"]) - def sastoken(self, mocker, request): - if request.param == "Renewable SAS Authentication": - mock_signing_mechanism = mocker.MagicMock() - mock_signing_mechanism.sign.return_value = fake_signed_data - sastoken = st.RenewableSasToken(uri=fake_uri, signing_mechanism=mock_signing_mechanism) - sastoken.refresh = mocker.MagicMock() - else: - token_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( - resource=fake_uri, signature=fake_signed_data, expiry=fake_expiry - ) - sastoken = st.NonRenewableSasToken(token_str) - return sastoken - - @pytest.mark.it("Cancels any existing token update alarm that may have been set") - def test_cancels_existing_alarm(self, mocker, mock_alarm, stage, op): - stage._token_update_alarm = mock_alarm - - stage.run_op(op) - - assert mock_alarm.cancel.call_count == 1 - assert mock_alarm.cancel.call_args == mocker.call() - - @pytest.mark.it("Resets the token update alarm to None until a new one is set") - # Edge case, since unless something goes wrong, the alarm WILL be set, and it's like - # it was never set to None. 
- def test_alarm_set_to_none_in_intermediate( - self, mocker, stage, op, mock_alarm, arbitrary_exception - ): - # Set an existing alarm - stage._token_update_alarm = mocker.MagicMock() - - # Set an error side effect on the alarm creation, so when a new alarm is created, - # we have an unhandled error causing op failure and early exit - mock_alarm.side_effect = arbitrary_exception - - stage.run_op(op) - - assert op.complete - assert op.error is arbitrary_exception - assert stage._token_update_alarm is None - - @pytest.mark.it( - "Starts a background update alarm that will trigger 'Update Margin' number of seconds prior to SasToken expiration" - ) - def test_sets_alarm(self, mocker, stage, op, mock_alarm): - expected_alarm_time = ( - stage.nucleus.pipeline_configuration.sastoken.expiry_time - - pipeline_stages_base.SasTokenStage.DEFAULT_TOKEN_UPDATE_MARGIN - ) - - stage.run_op(op) - - assert mock_alarm.call_count == 1 - assert mock_alarm.call_args[0][0] == expected_alarm_time - assert mock_alarm.return_value.daemon is True - assert mock_alarm.return_value.start.call_count == 1 - assert mock_alarm.return_value.start.call_args == mocker.call() - - @pytest.mark.it( - "Starts a background update alarm that will instead trigger after MAX_TIMEOUT seconds if the SasToken expiration time (less the Update Margin) is more than MAX_TIMEOUT seconds in the future" - ) - def test_sets_alarm_long_expiration(self, mocker, stage, op, mock_alarm): - token = stage.nucleus.pipeline_configuration.sastoken - new_expiry = token.expiry_time + threading.TIMEOUT_MAX - if isinstance(token, st.RenewableSasToken): - token._expiry_time = new_expiry - else: - token._token_info["se"] = new_expiry - # NOTE: time.time is implicitly mocked to return a constant test value here - expected_alarm_time = time.time() + threading.TIMEOUT_MAX - assert expected_alarm_time < token.expiry_time - - stage.run_op(op) - - assert mock_alarm.call_count == 1 - assert mock_alarm.call_args[0][0] == expected_alarm_time - 
assert mock_alarm.return_value.daemon is True - assert mock_alarm.return_value.start.call_count == 1 - assert mock_alarm.return_value.start.call_args == mocker.call() - - -@pytest.mark.describe( - "SasTokenStage - .run_op() -- Called with ReauthorizeConnectionOperation (Pipeline not configured for SAS authentication)" -) -class TestSasTokenStageRunOpWithReauthorizeConnectionOperationPipelineOpNoSasTokenConfig( - SasTokenStageTestConfig, StageRunOpTestBase -): - # NOTE: In practice this case will never happen. Currently ReauthorizeConnectionOperations only - # occur for SAS-based auth. Still, we test this combination of configurations for completeness - # and safety of having a defined behavior even for an impossible case, as we want to avoid - # using outside knowledge in unit-tests - without that knowledge of the rest of the client and - # pipeline, there's no reason to know that it couldn't happen. - - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.ReauthorizeConnectionOperation(callback=mocker.MagicMock()) - - @pytest.fixture - def sastoken(self): - return None - - @pytest.mark.it("Sends the operation down, WITHOUT setting a update alarm") - def test_sends_op_down_no_alarm(self, mocker, stage, mock_alarm, op): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - assert stage._token_update_alarm is None - assert mock_alarm.call_count == 0 - - -@pytest.mark.describe("SasTokenStage - .run_op() -- Called with ShutdownPipelineOperation") -class TestSasTokenStageRunOpWithShutdownPipelineOp(SasTokenStageTestConfig, StageRunOpTestBase): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.ShutdownPipelineOperation(callback=mocker.MagicMock()) - - @pytest.fixture( - params=[ - "Renewable SAS Authentication", - "Non-renewable SAS Authentication", - "No SAS Authentication", - ] - ) - def sastoken(self, mocker, request): - if request.param == "Renewable SAS 
Authentication": - mock_signing_mechanism = mocker.MagicMock() - mock_signing_mechanism.sign.return_value = fake_signed_data - sastoken = st.RenewableSasToken(uri=fake_uri, signing_mechanism=mock_signing_mechanism) - sastoken.refresh = mocker.MagicMock() - elif request.param == "Non-renewable SAS Authentication": - token_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( - resource=fake_uri, signature=fake_signed_data, expiry=fake_expiry - ) - sastoken = st.NonRenewableSasToken(token_str) - else: - sastoken = None - return sastoken - - @pytest.mark.it( - "Cancels the token update alarm and the reauth retry timer, then sends the operation down, if an alarm exists" - ) - def test_with_timer(self, mocker, stage, op, mock_alarm, mock_timer): - stage._token_update_alarm = mock_alarm - stage._reauth_retry_timer = mock_timer - assert mock_alarm.cancel.call_count == 0 - assert mock_timer.cancel.call_count == 0 - assert stage.send_op_down.call_count == 0 - - stage.run_op(op) - - assert mock_alarm.cancel.call_count == 1 - assert mock_timer.cancel.call_count == 1 - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - @pytest.mark.it("Simply sends the operation down if no alarm or timer exists") - def test_no_timer(self, mocker, stage, op): - assert stage._token_update_alarm is None - assert stage._reauth_retry_timer is None - assert stage.send_op_down.call_count == 0 - - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - -@pytest.mark.describe( - "SasTokenStage - OCCURRENCE: SasToken Update Alarm expires (Renew Token - RenewableSasToken)" -) -class TestSasTokenStageOCCURRENCEUpdateAlarmExpiresRenewToken(SasTokenStageTestConfig): - @pytest.fixture - def init_op(self, mocker): - return pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) - - @pytest.fixture - def sastoken(self, mocker): - # Renewable Token 
- mock_signing_mechanism = mocker.MagicMock() - mock_signing_mechanism.sign.return_value = fake_signed_data - sastoken = st.RenewableSasToken(uri=fake_uri, signing_mechanism=mock_signing_mechanism) - sastoken.refresh = mocker.MagicMock() - return sastoken - - @pytest.mark.it("Refreshes the pipeline's SasToken") - @pytest.mark.parametrize( - "connected", - [ - pytest.param(True, id="Pipeline connected"), - pytest.param(False, id="Pipeline not connected"), - ], - ) - def test_refresh_token( - self, mocker, stage, init_op, mock_alarm, connected, pipeline_connected_mock - ): - # Apply the alarm - stage.run_op(init_op) - - # Mock connected state - pipeline_connected_mock.return_value = connected - assert stage.nucleus.connected is connected - - # Token has not been refreshed - token = stage.nucleus.pipeline_configuration.sastoken - assert token.refresh.call_count == 0 - assert mock_alarm.call_count == 1 - - # Call alarm complete callback (as if alarm expired) - on_alarm_complete = mock_alarm.call_args[0][1] - on_alarm_complete() - - # Token has now been refreshed - assert token.refresh.call_count == 1 - - @pytest.mark.it( - "Reports any SasTokenError that occurs while refreshing the SasToken as a background exception" - ) - @pytest.mark.parametrize( - "connected", - [ - pytest.param(True, id="Pipeline connected"), - pytest.param(False, id="Pipeline not connected"), - ], - ) - def test_refresh_token_fail( - self, mocker, stage, init_op, mock_alarm, connected, pipeline_connected_mock - ): - # Apply the alarm - stage.run_op(init_op) - - # Mock connected state - pipeline_connected_mock.return_value = connected - assert stage.nucleus.connected is connected - - # Mock refresh - token = stage.nucleus.pipeline_configuration.sastoken - refresh_failure = st.SasTokenError() - token.refresh = mocker.MagicMock(side_effect=refresh_failure) - assert token.refresh.call_count == 0 - assert stage.report_background_exception.call_count == 0 - - # Call alarm complete callback (as if alarm 
expired) - on_alarm_complete = mock_alarm.call_args[0][1] - on_alarm_complete() - - assert token.refresh.call_count == 1 - assert stage.report_background_exception.call_count == 1 - assert stage.report_background_exception.call_args == mocker.call(refresh_failure) - - @pytest.mark.it("Cancels any reauth retry timer that may exist") - @pytest.mark.parametrize( - "connected", - [ - pytest.param(True, id="Pipeline connected"), - pytest.param(False, id="Pipeline not connected"), - ], - ) - def test_cancels_reauth_retry( - self, mocker, stage, init_op, mock_alarm, connected, pipeline_connected_mock - ): - # Apply the alarm - stage.run_op(init_op) - assert mock_alarm.call_count == 1 - - # Mock connected state and timer - mock_timer = mocker.MagicMock() - stage._reauth_retry_timer = mock_timer - pipeline_connected_mock.return_value = connected - assert stage.nucleus.connected is connected - - # Call alarm complete callback (as if alarm expired) - on_alarm_complete = mock_alarm.call_args[0][1] - on_alarm_complete() - - # The mock timer has been cancelled and unset - assert mock_timer.cancel.call_count == 1 - stage._reauth_retry_timer is None - - @pytest.mark.it( - "Sends a ReauthorizeConnectionOperation down the pipeline if the pipeline is in a 'connected' state" - ) - def test_when_pipeline_connected( - self, mocker, stage, init_op, mock_alarm, pipeline_connected_mock - ): - # Apply the alarm and mock pipeline as connected - pipeline_connected_mock.return_value = True - assert stage.nucleus.connected - stage.run_op(init_op) - - # Only the InitializePipeline init_op has been sent down - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(init_op) - - # Pipeline is still connected - assert stage.nucleus.connected - - # Call alarm complete callback (as if alarm expired) - assert mock_alarm.call_count == 1 - on_alarm_complete = mock_alarm.call_args[0][1] - on_alarm_complete() - - # ReauthorizeConnectionOperation has now been sent 
down - assert stage.send_op_down.call_count == 2 - assert isinstance( - stage.send_op_down.call_args[0][0], pipeline_ops_base.ReauthorizeConnectionOperation - ) - - @pytest.mark.it( - "Does NOT send a ReauthorizeConnectionOperation down the pipeline if the pipeline is NOT in a 'connected' state" - ) - def test_when_pipeline_not_connected( - self, mocker, stage, init_op, mock_alarm, pipeline_connected_mock - ): - # Apply the alarm and mock pipeline as disconnected - pipeline_connected_mock.return_value = False - assert not stage.nucleus.connected - stage.run_op(init_op) - - # Only the InitializePipeline init_op has been sent down - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(init_op) - - # Pipeline is still NOT connected - assert not stage.nucleus.connected - - # Call alarm complete callback (as if alarm expired) - on_alarm_complete = mock_alarm.call_args[0][1] - on_alarm_complete() - - # No further ops have been sent down - assert stage.send_op_down.call_count == 1 - - @pytest.mark.it( - "Begins a new SasToken update alarm that will trigger 'Update Margin' number of seconds prior to the refreshed SasToken expiration" - ) - @pytest.mark.parametrize( - "connected", - [ - pytest.param(True, id="Pipeline connected"), - pytest.param(False, id="Pipeline not connected"), - ], - ) - # I am sorry for this test length, but IDK how else to test this... - # ... other than throwing everything at it at once - def test_new_alarm( - self, mocker, stage, init_op, mock_alarm, connected, pipeline_connected_mock - ): - token = stage.nucleus.pipeline_configuration.sastoken - - # Mock connected state - pipeline_connected_mock.return_value = connected - assert stage.nucleus.connected is connected - - # Apply the alarm - stage.run_op(init_op) - - # init_op was passed down - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(init_op) - - # Only one alarm has been created and started. 
No cancellation. - assert mock_alarm.call_count == 1 - assert mock_alarm.return_value.start.call_count == 1 - assert mock_alarm.return_value.cancel.call_count == 0 - - # Call alarm complete callback (as if alarm expired) - on_alarm_complete = mock_alarm.call_args[0][1] - on_alarm_complete() - - # Existing alarm was cancelled - assert mock_alarm.return_value.cancel.call_count == 1 - - # Token was refreshed - assert token.refresh.call_count == 1 - - # Reauthorize was sent down (if the connection state was right) - if connected: - assert stage.send_op_down.call_count == 2 - assert isinstance( - stage.send_op_down.call_args[0][0], pipeline_ops_base.ReauthorizeConnectionOperation - ) - else: - assert stage.send_op_down.call_count == 1 - - # Another alarm was created and started for the expected time - assert mock_alarm.call_count == 2 - expected_alarm_time = ( - stage.nucleus.pipeline_configuration.sastoken.expiry_time - - pipeline_stages_base.SasTokenStage.DEFAULT_TOKEN_UPDATE_MARGIN - ) - assert mock_alarm.call_args[0][0] == expected_alarm_time - assert stage._token_update_alarm is mock_alarm.return_value - assert stage._token_update_alarm.daemon is True - assert stage._token_update_alarm.start.call_count == 2 - assert stage._token_update_alarm.start.call_args == mocker.call() - - # When THAT alarm expires, the token is refreshed, and the reauth is sent, etc. etc. etc. - # ... recursion :) - new_on_alarm_complete = mock_alarm.call_args[0][1] - new_on_alarm_complete() - - assert token.refresh.call_count == 2 - if connected: - assert stage.send_op_down.call_count == 3 - assert isinstance( - stage.send_op_down.call_args[0][0], pipeline_ops_base.ReauthorizeConnectionOperation - ) - else: - assert stage.send_op_down.call_count == 1 - - assert mock_alarm.call_count == 3 - # .... 
and on and on for infinity - - @pytest.mark.it( - "Begins a new SasToken update alarm that will instead trigger after MAX_TIMEOUT seconds if the refreshed SasToken expiration time (less the Update Margin) is more than MAX_TIMEOUT seconds in the future" - ) - @pytest.mark.parametrize( - "connected", - [ - pytest.param(True, id="Pipeline connected"), - pytest.param(False, id="Pipeline not connected"), - ], - ) - # I am sorry for this test length, but IDK how else to test this... - # ... other than throwing everything at it at once - def test_new_alarm_long_expiry( - self, mocker, stage, init_op, mock_alarm, connected, pipeline_connected_mock - ): - token = stage.nucleus.pipeline_configuration.sastoken - # Manually change the token TTL and expiry time to exceed max timeout - # Note that time.time() is implicitly mocked to return a constant value - token.ttl = threading.TIMEOUT_MAX + 3600 - token._expiry_time = int(time.time() + token.ttl) - - # Mock connected state - pipeline_connected_mock.return_value = connected - assert stage.nucleus.connected is connected - - # Apply the alarm - stage.run_op(init_op) - - # init_op was passed down - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(init_op) - - # Only one alarm has been created and started. No cancellation. 
- assert mock_alarm.call_count == 1 - assert mock_alarm.return_value.start.call_count == 1 - assert mock_alarm.return_value.cancel.call_count == 0 - - # Call alarm complete callback (as if alarm expired) - on_alarm_complete = mock_alarm.call_args[0][1] - on_alarm_complete() - - # Existing alarm was cancelled - assert mock_alarm.return_value.cancel.call_count == 1 - - # Token was refreshed - assert token.refresh.call_count == 1 - - # Reauthorize was sent down (if the connection state was right) - if connected: - assert stage.send_op_down.call_count == 2 - assert isinstance( - stage.send_op_down.call_args[0][0], pipeline_ops_base.ReauthorizeConnectionOperation - ) - else: - assert stage.send_op_down.call_count == 1 - - # Another alarm was created and started for the expected time - assert mock_alarm.call_count == 2 - # NOTE: time.time() is implicitly mocked to return a constant test value here - expected_alarm_time = time.time() + threading.TIMEOUT_MAX - assert mock_alarm.call_args[0][0] == expected_alarm_time - assert stage._token_update_alarm is mock_alarm.return_value - assert stage._token_update_alarm.daemon is True - assert stage._token_update_alarm.start.call_count == 2 - assert stage._token_update_alarm.start.call_args == mocker.call() - - # When THAT alarm expires, the token is refreshed, and the reauth is sent, etc. etc. etc. - # ... recursion :) - new_on_alarm_complete = mock_alarm.call_args[0][1] - new_on_alarm_complete() - - assert token.refresh.call_count == 2 - if connected: - assert stage.send_op_down.call_count == 3 - assert isinstance( - stage.send_op_down.call_args[0][0], pipeline_ops_base.ReauthorizeConnectionOperation - ) - else: - assert stage.send_op_down.call_count == 1 - - assert mock_alarm.call_count == 3 - # .... 
and on and on for infinity - - -@pytest.mark.describe( - "SasTokenStage - OCCURRENCE: SasToken Update Alarm expires (Replace Token - NonRenewableSasToken)" -) -class TestSasTokenStageOCCURRENCEUpdateAlarmExpiresReplaceToken(SasTokenStageTestConfig): - @pytest.fixture - def init_op(self, mocker): - return pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) - - @pytest.fixture - def sastoken(self, mocker): - # Non-Renewable Token - token_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( - resource=fake_uri, signature=fake_signed_data, expiry=fake_expiry - ) - sastoken = st.NonRenewableSasToken(token_str) - return sastoken - - @pytest.mark.it("Sends a NewSasTokenRequiredEvent up the pipeline") - @pytest.mark.parametrize( - "connected", - [ - pytest.param(True, id="Pipeline connected"), - pytest.param(False, id="Pipeline not connected"), - ], - ) - def test_sends_event(self, stage, init_op, mock_alarm, connected, pipeline_connected_mock): - # Mock connected state - pipeline_connected_mock.return_value = connected - assert stage.nucleus.connected is connected - # Apply the alarm - stage.run_op(init_op) - # Alarm was created - assert mock_alarm.call_count == 1 - # No events have been sent up the pipeline - assert stage.send_event_up.call_count == 0 - - # Call alarm complete callback (as if alarm expired) - on_alarm_complete = mock_alarm.call_args[0][1] - on_alarm_complete() - - # Event was sent up - assert stage.send_event_up.call_count == 1 - assert isinstance( - stage.send_event_up.call_args[0][0], pipeline_events_base.NewSasTokenRequiredEvent - ) - - -# NOTE: base tests for reauth fail suites. 
Reauth can be generated by two different conditions -# but need separate test classes for them, even though the tests themselves are the same -class SasTokenStageOCCURRENCEReauthorizeConnectionOperationFailsTests(SasTokenStageTestConfig): - @pytest.fixture - def sastoken(self, mocker): - # Renewable Token - mock_signing_mechanism = mocker.MagicMock() - mock_signing_mechanism.sign.return_value = fake_signed_data - sastoken = st.RenewableSasToken(uri=fake_uri, signing_mechanism=mock_signing_mechanism) - sastoken.refresh = mocker.MagicMock() - return sastoken - - # NOTE: you must implement a "reauth_op" fixture in subclass for these tests to run - - @pytest.mark.it("Reports a background exception") - @pytest.mark.parametrize( - "connected", - [ - pytest.param(True, id="Pipeline Connected"), # NOTE: this probably would never happen - pytest.param(False, id="Pipeline Disconnected"), - ], - ) - @pytest.mark.parametrize( - "connection_retry", - [ - pytest.param(True, id="Connection Retry Enabled"), - pytest.param(False, id="Connection Retry Disabled"), - ], - ) - def test_reports_background_exception( - self, - mocker, - stage, - reauth_op, - arbitrary_exception, - connected, - connection_retry, - pipeline_connected_mock, - ): - assert stage.report_background_exception.call_count == 0 - - # Mock the connection state and set the retry feature - pipeline_connected_mock.return_value = connected - assert stage.nucleus.connected is connected - stage.nucleus.pipeline_configuration.connection_retry = connection_retry - - # Complete ReauthorizeConnectionOperation with error - reauth_op.complete(error=arbitrary_exception) - - # Error was sent to background handler - assert stage.report_background_exception.call_count == 1 - assert stage.report_background_exception.call_args == mocker.call(arbitrary_exception) - - @pytest.mark.it( - "Starts a reauth retry timer for the connection retry interval if the pipeline is not connected and connection retry is enabled on the pipeline" - ) - 
def test_starts_retry_timer( - self, mocker, stage, reauth_op, arbitrary_exception, mock_timer, pipeline_connected_mock - ): - pipeline_connected_mock.return_value = False - assert not stage.nucleus.connected - stage.nucleus.pipeline_configuration.connection_retry = True - - assert mock_timer.call_count == 0 - - reauth_op.complete(error=arbitrary_exception) - - assert mock_timer.call_count == 1 - assert mock_timer.call_args == mocker.call( - stage.nucleus.pipeline_configuration.connection_retry_interval, mocker.ANY - ) - assert mock_timer.return_value.start.call_count == 1 - assert mock_timer.return_value.start.call_args == mocker.call() - assert mock_timer.return_value.daemon is True - - -@pytest.mark.describe( - "SasTokenStage - OCCURRENCE: ReauthorizeConnectionOperation sent by SasToken Update Alarm fails" -) -class TestSasTokenStageOCCURRENCEReauthorizeConnectionOperationFromAlarmFails( - SasTokenStageOCCURRENCEReauthorizeConnectionOperationFailsTests -): - @pytest.fixture - def reauth_op(self, mocker, stage, mock_alarm, pipeline_connected_mock): - # Initialize the pipeline - pipeline_connected_mock.return_value = True - init_op = pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) - stage.run_op(init_op) - - # Call alarm complete callback (as if alarm expired) - assert mock_alarm.call_count == 1 - on_alarm_complete = mock_alarm.call_args[0][1] - on_alarm_complete() - - # ReauthorizeConnectionOperation has now been sent down - assert stage.send_op_down.call_count == 2 - reauth_op = stage.send_op_down.call_args[0][0] - assert isinstance(reauth_op, pipeline_ops_base.ReauthorizeConnectionOperation) - - # Reset mocks - mock_alarm.reset_mock() - return reauth_op - - -@pytest.mark.describe("SasTokenStage - OCCURRENCE: Reauth Retry Timer expires") -class TestSasTokenStageOCCURRENCEReauthRetryTimerExpires(SasTokenStageTestConfig): - @pytest.fixture - def init_op(self, mocker): - return 
pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) - - @pytest.fixture - def sastoken(self, mocker): - # Renewable Token - mock_signing_mechanism = mocker.MagicMock() - mock_signing_mechanism.sign.return_value = fake_signed_data - sastoken = st.RenewableSasToken(uri=fake_uri, signing_mechanism=mock_signing_mechanism) - sastoken.refresh = mocker.MagicMock() - return sastoken - - @pytest.mark.it( - "Sends a ReauthorizeConnectionOperation down the pipeline if the pipeline is still not connected" - ) - def test_while_disconnected( - self, - mocker, - stage, - init_op, - mock_alarm, - mock_timer, - arbitrary_exception, - pipeline_connected_mock, - ): - # Initialize stage with alarm - pipeline_connected_mock.return_value = True - assert stage.nucleus.connected - stage.nucleus.pipeline_configuration.connection_retry = True - stage.run_op(init_op) - - # Only the InitializePipeline op has been sent down - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(init_op) - - # Pipeline is still connected - assert stage.nucleus.connected - - # Call alarm complete callback (as if alarm expired) - assert mock_alarm.call_count == 1 - assert stage._token_update_alarm is mock_alarm.return_value - on_alarm_complete = mock_alarm.call_args[0][1] - on_alarm_complete() - - # First ReauthorizeConnectionOperation has now been sent down - assert stage.send_op_down.call_count == 2 - reauth_op = stage.send_op_down.call_args[0][0] - assert isinstance(reauth_op, pipeline_ops_base.ReauthorizeConnectionOperation) - - # Complete the ReauthorizeConnectionOperation with failure, triggering retry - pipeline_connected_mock.return_value = False - assert not stage.nucleus.connected - reauth_op.complete(error=arbitrary_exception) - - # Call timer complete callback (as if timer expired) - assert mock_timer.call_count == 1 - assert stage._reauth_retry_timer is mock_timer.return_value - assert not stage.nucleus.connected - on_timer_complete = 
mock_timer.call_args[0][1] - on_timer_complete() - - # ReauthorizeConnectionOperation has now been sent down - assert stage.send_op_down.call_count == 3 - reauth_op = stage.send_op_down.call_args[0][0] - assert isinstance(reauth_op, pipeline_ops_base.ReauthorizeConnectionOperation) - - @pytest.mark.it( - "Does not send a ReauthorizeConnectionOperation if the pipeline is now connected" - ) - def test_while_connected( - self, - mocker, - stage, - init_op, - mock_alarm, - mock_timer, - arbitrary_exception, - pipeline_connected_mock, - ): - # Initialize stage with alarm - pipeline_connected_mock.return_value = True - assert stage.nucleus.connected - stage.nucleus.pipeline_configuration.connection_retry = True - stage.run_op(init_op) - - # Only the InitializePipeline op has been sent down - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(init_op) - - # Pipeline is still connected - assert stage.nucleus.connected - - # Call alarm complete callback (as if alarm expired) - assert mock_alarm.call_count == 1 - assert stage._token_update_alarm is mock_alarm.return_value - on_alarm_complete = mock_alarm.call_args[0][1] - on_alarm_complete() - - # First ReauthorizeConnectionOperation has now been sent down - assert stage.send_op_down.call_count == 2 - reauth_op = stage.send_op_down.call_args[0][0] - assert isinstance(reauth_op, pipeline_ops_base.ReauthorizeConnectionOperation) - - # Complete the ReauthorizeConnectionOperation with failure, triggering retry - pipeline_connected_mock.return_value = False - assert not stage.nucleus.connected - reauth_op.complete(error=arbitrary_exception) - - # Call timer complete callback (as if timer expired) - assert mock_timer.call_count == 1 - assert stage._reauth_retry_timer is mock_timer.return_value - pipeline_connected_mock.return_value = True - assert stage.nucleus.connected # Re-establish before timer completes - on_timer_complete = mock_timer.call_args[0][1] - on_timer_complete() - - # 
Nothing else been sent down - assert stage.send_op_down.call_count == 2 - - -@pytest.mark.describe( - "SasTokenStage - OCCURRENCE: ReauthorizeConnectionOperation sent by Reauth Retry Timer fails" -) -class TestSasTokenStageOCCURRENCEReauthorizeConnectionOperationFromTimerFails( - SasTokenStageOCCURRENCEReauthorizeConnectionOperationFailsTests -): - @pytest.fixture - def reauth_op( - self, mocker, stage, mock_alarm, mock_timer, arbitrary_exception, pipeline_connected_mock - ): - # Initialize the pipeline - pipeline_connected_mock.return_value = True - assert stage.nucleus.connected - init_op = pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) - stage.run_op(init_op) - - # Call alarm complete callback (as if alarm expired) - assert mock_alarm.call_count == 1 - on_alarm_complete = mock_alarm.call_args[0][1] - on_alarm_complete() - - # ReauthorizeConnectionOperation has now been sent down - assert stage.send_op_down.call_count == 2 - reauth_op = stage.send_op_down.call_args[0][0] - assert isinstance(reauth_op, pipeline_ops_base.ReauthorizeConnectionOperation) - - # Complete the ReauthorizeConnectionOperation with failure, triggering retry - pipeline_connected_mock.return_value = False - assert not stage.nucleus.connected - reauth_op.complete(error=arbitrary_exception) - - # Call timer complete callback (as if timer expired) - assert mock_timer.call_count == 1 - assert stage._reauth_retry_timer is mock_timer.return_value - assert not stage.nucleus.connected - assert stage.report_background_exception.call_count == 1 - on_timer_complete = mock_timer.call_args[0][1] - on_timer_complete() - - # ReauthorizeConnectionOperation has now been sent down - assert stage.send_op_down.call_count == 3 - reauth_op = stage.send_op_down.call_args[0][0] - assert isinstance(reauth_op, pipeline_ops_base.ReauthorizeConnectionOperation) - - # Reset mocks - mock_timer.reset_mock() - mock_alarm.reset_mock() - stage.report_background_exception.reset_mock() - return 
reauth_op - - -###################### -# AUTO CONNECT STAGE # -###################### - - -class AutoConnectStageTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_base.AutoConnectStage - - @pytest.fixture - def init_kwargs(self, mocker): - return {} - - @pytest.fixture - def pl_config(self, mocker): - pl_cfg = mocker.MagicMock() - pl_cfg.auto_connect = True - return pl_cfg - - @pytest.fixture - def stage(self, mocker, nucleus, pl_config, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.nucleus.pipeline_configuration = pl_config - # Mock flow methods - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_base.AutoConnectStage, - stage_test_config_class=AutoConnectStageTestConfig, -) - - -@pytest.mark.describe( - "AutoConnectStage - .run_op() -- Called with an Operation that requires an active connection (pipeline already connected)" -) -class TestAutoConnectStageRunOpWithOpThatRequiresConnectionPipelineConnected( - AutoConnectStageTestConfig, StageRunOpTestBase -): - - fake_topic = "__fake_topic__" - fake_payload = "__fake_payload__" - - ops_requiring_connection = [ - pipeline_ops_mqtt.MQTTPublishOperation, - pipeline_ops_mqtt.MQTTSubscribeOperation, - pipeline_ops_mqtt.MQTTUnsubscribeOperation, - ] - - @pytest.fixture(params=ops_requiring_connection) - def op(self, mocker, request): - op_class = request.param - if op_class is pipeline_ops_mqtt.MQTTPublishOperation: - op = op_class( - topic=self.fake_topic, payload=self.fake_payload, callback=mocker.MagicMock() - ) - else: - op = op_class(topic=self.fake_topic, callback=mocker.MagicMock()) - assert op.needs_connection - return op - - @pytest.mark.it("Immediately sends the operation down the pipeline") - def 
test_already_connected(self, mocker, stage, op, pipeline_connected_mock): - pipeline_connected_mock.return_value = True - assert stage.nucleus.connected - - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - -@pytest.mark.describe( - "AutoConnectStage - .run_op() -- Called with an Operation that requires an active connection (pipeline not connected)" -) -class TestAutoConnectStageRunOpWithOpThatRequiresConnectionNotConnected( - AutoConnectStageTestConfig, StageRunOpTestBase -): - - fake_topic = "__fake_topic__" - fake_payload = "__fake_payload__" - - ops_requiring_connection = [ - pipeline_ops_mqtt.MQTTPublishOperation, - pipeline_ops_mqtt.MQTTSubscribeOperation, - pipeline_ops_mqtt.MQTTUnsubscribeOperation, - ] - - @pytest.fixture(params=ops_requiring_connection) - def op(self, mocker, request): - op_class = request.param - if op_class is pipeline_ops_mqtt.MQTTPublishOperation: - op = op_class( - topic=self.fake_topic, payload=self.fake_payload, callback=mocker.MagicMock() - ) - else: - op = op_class(topic=self.fake_topic, callback=mocker.MagicMock()) - assert op.needs_connection - return op - - @pytest.mark.it("Sends a new ConnectOperation down the pipeline") - def test_not_connected(self, mocker, stage, op, pipeline_connected_mock): - mock_connect_op = mocker.patch.object(pipeline_ops_base, "ConnectOperation").return_value - pipeline_connected_mock.return_value = False - assert not stage.nucleus.connected - - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(mock_connect_op) - - @pytest.mark.it( - "Sends the operation down the pipeline once the ConnectOperation completes successfully" - ) - def test_connect_success(self, mocker, stage, op, pipeline_connected_mock): - pipeline_connected_mock.return_value = False - assert not stage.nucleus.connected - mocker.spy(stage, "run_op") - - # Run the original operation - 
stage.run_op(op) - assert not op.completed - - # Complete the newly created ConnectOperation that was sent down the pipeline - assert stage.send_op_down.call_count == 1 - connect_op = stage.send_op_down.call_args[0][0] - assert isinstance(connect_op, pipeline_ops_base.ConnectOperation) - assert not connect_op.completed - connect_op.complete() # no error - - # The original operation has now been sent down the pipeline - assert stage.run_op.call_count == 2 - assert stage.run_op.call_args == mocker.call(op) - - @pytest.mark.it( - "Completes the operation with the error from the ConnectOperation, if the ConnectOperation completes with an error" - ) - def test_connect_failure(self, mocker, stage, op, arbitrary_exception, pipeline_connected_mock): - pipeline_connected_mock.return_value = False - assert not stage.nucleus.connected - - # Run the original operation - stage.run_op(op) - assert not op.completed - - # Complete the newly created ConnectOperation that was sent down the pipeline - assert stage.send_op_down.call_count == 1 - connect_op = stage.send_op_down.call_args[0][0] - assert isinstance(connect_op, pipeline_ops_base.ConnectOperation) - assert not connect_op.completed - connect_op.complete(error=arbitrary_exception) # completes with error - - # The original operation has been completed the exception from the ConnectOperation - assert op.completed - assert op.error is arbitrary_exception - - -@pytest.mark.describe( - "AutoConnectStage - .run_op() -- Called with an Operation that does not require an active connection" -) -class TestAutoConnectStageRunOpWithOpThatDoesNotRequireConnection( - AutoConnectStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def op(self, arbitrary_op): - assert not arbitrary_op.needs_connection - return arbitrary_op - - @pytest.mark.it( - "Sends the operation down the pipeline if the pipeline is in a 'connected' state" - ) - def test_connected(self, mocker, stage, op, pipeline_connected_mock): - 
pipeline_connected_mock.return_value = True - assert stage.nucleus.connected - - stage.run_op(op) - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - @pytest.mark.it( - "Sends the operation down the pipeline if the pipeline is in a 'disconnected' state" - ) - def test_disconnected(self, mocker, stage, op, pipeline_connected_mock): - pipeline_connected_mock.return_value = False - assert not stage.nucleus.connected - - stage.run_op(op) - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - -@pytest.mark.describe( - "AutoConnectStage - .run_op() -- Called while pipeline configured to disable Auto Connect" -) -class TestAutoConnectStageRunOpWithAutoConnectDisabled( - AutoConnectStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def pl_config(self, mocker): - pl_cfg = mocker.MagicMock() - pl_cfg.auto_connect = False - return pl_cfg - - @pytest.fixture(params=["Op requires connection", "Op does NOT require connection"]) - def op(self, request, arbitrary_op): - if request.param == "Op requires connection": - arbitrary_op.needs_connection = True - else: - arbitrary_op.needs_connection = False - return arbitrary_op - - @pytest.mark.it( - "Sends the operation down the pipeline if the pipeline is in a 'connected' state" - ) - def test_connected(self, mocker, stage, op, pipeline_connected_mock): - pipeline_connected_mock.return_value = True - assert stage.nucleus.connected - - stage.run_op(op) - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - @pytest.mark.it( - "Sends the operation down the pipeline if the pipeline is in a 'disconnected' state" - ) - def test_disconnected(self, mocker, stage, op, pipeline_connected_mock): - pipeline_connected_mock.return_value = False - assert not stage.nucleus.connected - - stage.run_op(op) - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == 
mocker.call(op) - - -######################################### -# COORDINATE REQUEST AND RESPONSE STAGE # -######################################### - - -@pytest.fixture -def fake_uuid(mocker): - my_uuid = "0f4f876b-f445-432e-a8de-43bbd66e4668" - uuid4_mock = mocker.patch.object(uuid, "uuid4") - uuid4_mock.return_value.__str__.return_value = my_uuid - return my_uuid - - -class CoordinateRequestAndResponseStageTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_base.CoordinateRequestAndResponseStage - - @pytest.fixture - def init_kwargs(self, mocker): - return {} - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - -class CoordinateRequestAndResponseStageInstantiationTests( - CoordinateRequestAndResponseStageTestConfig -): - @pytest.mark.it("Initializes 'pending_responses' as an empty dict") - def test_pending_responses(self, init_kwargs): - stage = pipeline_stages_base.CoordinateRequestAndResponseStage(**init_kwargs) - assert stage.pending_responses == {} - - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_base.CoordinateRequestAndResponseStage, - stage_test_config_class=CoordinateRequestAndResponseStageTestConfig, - extended_stage_instantiation_test_class=CoordinateRequestAndResponseStageInstantiationTests, -) - - -@pytest.mark.describe( - "CoordinateRequestAndResponseStage - .run_op() -- Called with a RequestAndResponseOperation" -) -class TestCoordinateRequestAndResponseStageRunOpWithRequestAndResponseOperation( - CoordinateRequestAndResponseStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.RequestAndResponseOperation( - request_type="some_request_type", - 
method="SOME_METHOD", - resource_location="some/resource/location", - request_body="some_request_body", - callback=mocker.MagicMock(), - ) - - @pytest.mark.it( - "Stores the operation in the 'pending_responses' dictionary, mapped with a generated UUID" - ) - def test_stores_op(self, mocker, stage, op, fake_uuid): - stage.run_op(op) - - assert stage.pending_responses[fake_uuid] is op - assert not op.completed - - @pytest.mark.it( - "Creates and a new RequestOperation using the generated UUID and sends it down the pipeline" - ) - def test_sends_down_new_request_op(self, mocker, stage, op, fake_uuid): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - request_op = stage.send_op_down.call_args[0][0] - assert isinstance(request_op, pipeline_ops_base.RequestOperation) - assert request_op.method == op.method - assert request_op.resource_location == op.resource_location - assert request_op.request_body == op.request_body - assert request_op.request_type == op.request_type - assert request_op.request_id == fake_uuid - - @pytest.mark.it( - "Generates a unique UUID for each RequestAndResponseOperation/RequestOperation pair" - ) - def test_unique_uuid(self, stage, op): - op1 = op - op2 = copy.deepcopy(op) - op3 = copy.deepcopy(op) - - stage.run_op(op1) - assert stage.send_op_down.call_count == 1 - uuid1 = stage.send_op_down.call_args[0][0].request_id - stage.run_op(op2) - assert stage.send_op_down.call_count == 2 - uuid2 = stage.send_op_down.call_args[0][0].request_id - stage.run_op(op3) - assert stage.send_op_down.call_count == 3 - uuid3 = stage.send_op_down.call_args[0][0].request_id - - assert uuid1 != uuid2 != uuid3 - assert stage.pending_responses[uuid1] is op1 - assert stage.pending_responses[uuid2] is op2 - assert stage.pending_responses[uuid3] is op3 - - -@pytest.mark.describe( - "CoordinateRequestAndResponseStage - .run_op() -- Called with an arbitrary other operation" -) -class TestCoordinateRequestAndResponseStageRunOpWithArbitraryOperation( - 
CoordinateRequestAndResponseStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def op(self, arbitrary_op): - return arbitrary_op - - @pytest.mark.it("Sends the operation down the pipeline") - def test_sends_down(self, stage, mocker, op): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - -@pytest.mark.describe( - "CoordinateRequestAndResponseStage - OCCURRENCE: RequestOperation tied to a stored RequestAndResponseOperation is completed" -) -class TestCoordinateRequestAndResponseStageRequestOperationCompleted( - CoordinateRequestAndResponseStageTestConfig -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.RequestAndResponseOperation( - request_type="some_request_type", - method="SOME_METHOD", - resource_location="some/resource/location", - request_body="some_request_body", - callback=mocker.MagicMock(), - ) - - @pytest.mark.it( - "Completes the associated RequestAndResponseOperation with the error from the RequestOperation and removes it from the 'pending_responses' dict, if the RequestOperation is completed unsuccessfully" - ) - def test_request_completed_with_error(self, stage, op, arbitrary_exception): - stage.run_op(op) - request_op = stage.send_op_down.call_args[0][0] - - assert not op.completed - assert not request_op.completed - assert stage.pending_responses[request_op.request_id] is op - - request_op.complete(error=arbitrary_exception) - - # RequestAndResponseOperation has been completed with the error from the RequestOperation - assert request_op.completed - assert op.completed - assert op.error is request_op.error is arbitrary_exception - - # RequestAndResponseOperation has been removed from the 'pending_responses' dict - with pytest.raises(KeyError): - stage.pending_responses[request_op.request_id] - - @pytest.mark.it( - "Does not complete or remove the RequestAndResponseOperation from the 'pending_responses' dict if the RequestOperation is completed 
successfully" - ) - def test_request_completed_successfully(self, stage, op): - stage.run_op(op) - request_op = stage.send_op_down.call_args[0][0] - - request_op.complete() - - assert request_op.completed - assert not op.completed - assert stage.pending_responses[request_op.request_id] is op - - @pytest.mark.it( - "Does not remove a no-longer existing RequestAndResponseOperation from the 'pending_responses' dict, if the RequestOperation is completed unsuccessfully" - ) - def test_deleted_request_completed_unsuccessfully(self, stage, op, arbitrary_exception): - stage.run_op(op) - request_op = stage.send_op_down.call_args[0][0] - - assert stage.pending_responses[request_op.request_id] is op - - # Complete and remove the RequestAndResponseOperation - op.complete() - del stage.pending_responses[request_op.request_id] - - assert request_op.request_id not in stage.pending_responses - - # Complete the RequestOperation - request_op.complete(error=arbitrary_exception) - # There are no further assertions because, if this does not raise an error, - # the test is successful - - -@pytest.mark.describe( - "CoordinateRequestAndResponseStage - .handle_pipeline_event() -- Called with ResponseEvent" -) -class TestCoordinateRequestAndResponseStageHandlePipelineEventWithResponseEvent( - CoordinateRequestAndResponseStageTestConfig, StageHandlePipelineEventTestBase -): - @pytest.fixture - def event(self, fake_uuid): - return pipeline_events_base.ResponseEvent( - request_id=fake_uuid, status_code=200, response_body="response body" - ) - - @pytest.fixture - def pending_op(self, mocker): - return pipeline_ops_base.RequestAndResponseOperation( - request_type="some_request_type", - method="SOME_METHOD", - resource_location="some/resource/location", - request_body="some_request_body", - callback=mocker.MagicMock(), - ) - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, fake_uuid, nucleus, pending_op): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - 
stage.send_event_up = mocker.MagicMock() - stage.send_op_down = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - - # Run the pending op - stage.run_op(pending_op) - return stage - - @pytest.mark.it( - "Successfully completes a pending RequestAndResponseOperation that matches the 'request_id' of the ResponseEvent, and removes it from the 'pending_responses' dictionary" - ) - def test_completes_matching_request_and_response_operation( - self, stage, pending_op, event, fake_uuid - ): - assert stage.pending_responses[fake_uuid] is pending_op - assert not pending_op.completed - - # Handle the ResponseEvent - assert event.request_id == fake_uuid - stage.handle_pipeline_event(event) - - # The pending RequestAndResponseOperation is complete - assert pending_op.completed - - # The RequestAndResponseOperation has been removed from the dictionary - with pytest.raises(KeyError): - stage.pending_responses[fake_uuid] - - @pytest.mark.it( - "Sets the 'status_code' and 'response_body' attributes on the completed RequestAndResponseOperation with values from the ResponseEvent" - ) - def test_returns_values_in_attributes(self, mocker, stage, pending_op, event): - assert not pending_op.completed - assert pending_op.status_code is None - assert pending_op.response_body is None - - stage.handle_pipeline_event(event) - - assert pending_op.completed - assert pending_op.status_code == event.status_code - assert pending_op.response_body == event.response_body - - @pytest.mark.it( - "Does nothing if there is no pending RequestAndResponseOperation that matches the 'request_id' of the ResponseEvent" - ) - def test_no_matching_request_id(self, mocker, stage, pending_op, event, fake_uuid): - assert stage.pending_responses[fake_uuid] is pending_op - assert not pending_op.completed - - # Use a non-matching UUID - event.request_id = "non-matching-uuid" - assert event.request_id != fake_uuid - stage.handle_pipeline_event(event) - - # Nothing has changed - assert 
stage.pending_responses[fake_uuid] is pending_op - assert not pending_op.completed - - -@pytest.mark.describe( - "CoordinateRequestAndResponseStage - .handle_pipeline_event() -- Called with ConnectedEvent" -) -class TestCoordinateRequestAndResponseStageHandlePipelineEventWithConnectedEvent( - CoordinateRequestAndResponseStageTestConfig, StageHandlePipelineEventTestBase -): - @pytest.fixture - def event(self): - return pipeline_events_base.ConnectedEvent() - - def make_new_request_response_op(self, mocker, id): - return pipeline_ops_base.RequestAndResponseOperation( - request_type="some_request_type", - method="SOME_METHOD", - resource_location="some/resource/location/{}".format(id), - request_body="some_request_body", - callback=mocker.MagicMock(), - ) - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.send_event_up = mocker.MagicMock() - stage.send_op_down = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - - return stage - - @pytest.mark.it("Sends a RequestOperation down again if that RequestOperation never completed") - def test_request_never_completed(self, stage, event, mocker): - op1 = self.make_new_request_response_op(mocker, "op1") - - # send it down but don't complete it - stage.run_op(op1) - assert stage.send_op_down.call_count == 1 - op1_guid = stage.send_op_down.call_args[0][0].request_id - - # simulate a connected event - stage.handle_pipeline_event(event) - - # verify that it got sent again - assert stage.send_op_down.call_count == 2 - assert isinstance(stage.send_op_down.call_args[0][0], pipeline_ops_base.RequestOperation) - assert stage.send_op_down.call_args[0][0].request_id == op1_guid - - @pytest.mark.it( - "Sends a RequestOperation down again if that RequestOperation completed, but no corresponding ResponseEvent was received" - ) - def test_response_never_received(self, stage, event, mocker): - op1 = 
self.make_new_request_response_op(mocker, "op1") - - # send it down and completed the RequestOperation - stage.run_op(op1) - assert stage.send_op_down.call_count == 1 - op1_guid = stage.send_op_down.call_args[0][0].request_id - stage.send_op_down.call_args[0][0].complete() - - # simulate a connected event - stage.handle_pipeline_event(event) - - # verify that it got sent again - assert stage.send_op_down.call_count == 2 - assert isinstance(stage.send_op_down.call_args[0][0], pipeline_ops_base.RequestOperation) - assert stage.send_op_down.call_args[0][0].request_id == op1_guid - - @pytest.mark.it( - "Sends down multiple RequestOperations again if those RequestOperations never completed" - ) - def test_multiple_requests_never_completed(self, stage, event, mocker): - op1 = self.make_new_request_response_op(mocker, "op1") - op2 = self.make_new_request_response_op(mocker, "op2") - - # send 2 ops down but don't complete them - stage.run_op(op1) - assert stage.send_op_down.call_count == 1 - op1_guid = stage.send_op_down.call_args[0][0].request_id - - stage.run_op(op2) - assert stage.send_op_down.call_count == 2 - op2_guid = stage.send_op_down.call_args[0][0].request_id - - # simulate a connected event - stage.handle_pipeline_event(event) - - # verify that 2 more RequestOperation ops were send down - assert stage.send_op_down.call_count == 4 - assert isinstance( - stage.send_op_down.call_args_list[0][0][0], pipeline_ops_base.RequestOperation - ) - assert isinstance( - stage.send_op_down.call_args_list[1][0][0], pipeline_ops_base.RequestOperation - ) - - assert ( - stage.send_op_down.call_args_list[2][0][0].request_id == op1_guid - and stage.send_op_down.call_args_list[3][0][0].request_id == op2_guid - ) or ( - stage.send_op_down.call_args_list[2][0][0].request_id == op2_guid - and stage.send_op_down.call_args_list[3][0][0].request_id == op1_guid - ) - - @pytest.mark.it( - "Sends down multiple RequestOperations again if those RequestOperations completed, but the 
corresponding ResponseEvents were never received" - ) - def test_multiple_responses_never_received(self, stage, event, mocker): - op1 = self.make_new_request_response_op(mocker, "op1") - op2 = self.make_new_request_response_op(mocker, "op2") - - # send 2 ops down - stage.run_op(op1) - assert stage.send_op_down.call_count == 1 - op1_guid = stage.send_op_down.call_args[0][0].request_id - - stage.run_op(op2) - assert stage.send_op_down.call_count == 2 - op2_guid = stage.send_op_down.call_args[0][0].request_id - - # complete the 2 RequestOperation ops - stage.send_op_down.call_arg_list[0][0][0].complete() - stage.send_op_down.call_arg_list[1][0][0].complete() - - # simulate a connected event - stage.handle_pipeline_event(event) - - # verify that 2 more RequestOperation ops were send down - assert stage.send_op_down.call_count == 4 - assert isinstance( - stage.send_op_down.call_args_list[0][0][0], pipeline_ops_base.RequestOperation - ) - assert isinstance( - stage.send_op_down.call_args_list[1][0][0], pipeline_ops_base.RequestOperation - ) - - assert ( - stage.send_op_down.call_args_list[2][0][0].request_id == op1_guid - and stage.send_op_down.call_args_list[3][0][0].request_id == op2_guid - ) or ( - stage.send_op_down.call_args_list[2][0][0].request_id == op2_guid - and stage.send_op_down.call_args_list[3][0][0].request_id == op1_guid - ) - - @pytest.mark.it( - "Does not send down any RequestOperations if the RequestAndResponseOperation completed" - ) - def test_request_and_response_completed(self, stage, event, mocker): - op1 = self.make_new_request_response_op(mocker, "op1") - - # send it down and complete the RequestOperation op - stage.run_op(op1) - assert stage.send_op_down.call_count == 1 - op1_guid = stage.send_op_down.call_args[0][0].request_id - stage.send_op_down.call_args[0][0].complete() - - # simulate the corresponding response - response_event = pipeline_events_base.ResponseEvent( - request_id=op1_guid, status_code=200, response_body="response body" - ) - 
stage.handle_pipeline_event(response_event) - - # verify that the op is complete - assert op1.completed - - # simulate a connected event - stage.handle_pipeline_event(event) - - # verify that nothing else was sent - assert stage.send_op_down.call_count == 1 - - @pytest.mark.it("Can independently track and resend multiple RequestOperations") - def test_one_completed_one_outstanding(self, stage, event, mocker): - op1 = self.make_new_request_response_op(mocker, "op1") - op2 = self.make_new_request_response_op(mocker, "op2") - - # send 2 ops down - stage.run_op(op1) - assert stage.send_op_down.call_count == 1 - op1_guid = stage.send_op_down.call_args[0][0].request_id - - stage.run_op(op2) - assert stage.send_op_down.call_count == 2 - op2_guid = stage.send_op_down.call_args[0][0].request_id - - # complete the 2 RequestOperation ops - stage.send_op_down.call_arg_list[0][0][0].complete() - stage.send_op_down.call_arg_list[1][0][0].complete() - - # simulate a response for the first RequestOperation - response_event = pipeline_events_base.ResponseEvent( - request_id=op1_guid, status_code=200, response_body="response body" - ) - stage.handle_pipeline_event(response_event) - - # simulate a connected event - stage.handle_pipeline_event(event) - - # verify that the re-sent RequestOperation was sent down for the incomplete RequestAndResponseOperation - assert stage.send_op_down.call_count == 3 - assert isinstance(stage.send_op_down.call_args[0][0], pipeline_ops_base.RequestOperation) - assert stage.send_op_down.call_args[0][0].request_id == op2_guid - - @pytest.mark.it( - "Does not send down any RequestOperations if all those RequestAndResponseOperations are complete" - ) - def test_all_completed(self, stage, event, mocker): - op1 = self.make_new_request_response_op(mocker, "op1") - op2 = self.make_new_request_response_op(mocker, "op2") - - # send 2 ops down - stage.run_op(op1) - assert stage.send_op_down.call_count == 1 - op1_guid = stage.send_op_down.call_args[0][0].request_id 
- - stage.run_op(op2) - assert stage.send_op_down.call_count == 2 - op2_guid = stage.send_op_down.call_args[0][0].request_id - - # complete the 2 RequestOperation ops - stage.send_op_down.call_arg_list[0][0][0].complete() - stage.send_op_down.call_arg_list[1][0][0].complete() - - # simulate 2 responses - stage.handle_pipeline_event( - pipeline_events_base.ResponseEvent( - request_id=op1_guid, status_code=200, response_body="response body" - ) - ) - stage.handle_pipeline_event( - pipeline_events_base.ResponseEvent( - request_id=op2_guid, status_code=200, response_body="response body" - ) - ) - - # simulate a connected event - stage.handle_pipeline_event(event) - - # verify that nothing else was sent down - assert stage.send_op_down.call_count == 2 - - -@pytest.mark.describe( - "CoordinateRequestAndResponseStage - .handle_pipeline_event() -- Called with arbitrary other event" -) -class TestCoordinateRequestAndResponseStageHandlePipelineEventWithArbitraryEvent( - CoordinateRequestAndResponseStageTestConfig, StageHandlePipelineEventTestBase -): - @pytest.fixture - def event(self, arbitrary_event): - return arbitrary_event - - @pytest.mark.it("Sends the event up the pipeline") - def test_sends_up(self, mocker, stage, event): - stage.handle_pipeline_event(event) - - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args == mocker.call(event) - - -#################### -# OP TIMEOUT STAGE # -#################### - -ops_that_time_out = [ - pipeline_ops_mqtt.MQTTSubscribeOperation, - pipeline_ops_mqtt.MQTTUnsubscribeOperation, -] - - -class OpTimeoutStageTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_base.OpTimeoutStage - - @pytest.fixture - def init_kwargs(self, mocker): - return {} - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - 
mocker.spy(stage, "report_background_exception") - return stage - - -class OpTimeoutStageInstantiationTests(OpTimeoutStageTestConfig): - # NOTE: this will no longer be necessary once these are implemented as part of a more robust retry policy - @pytest.mark.it( - "Sets default timeout intervals to 10 seconds for MQTTSubscribeOperation and MQTTUnsubscribeOperation" - ) - def test_timeout_intervals(self, init_kwargs): - stage = pipeline_stages_base.OpTimeoutStage(**init_kwargs) - assert stage.timeout_intervals[pipeline_ops_mqtt.MQTTSubscribeOperation] == 10 - assert stage.timeout_intervals[pipeline_ops_mqtt.MQTTUnsubscribeOperation] == 10 - - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_base.OpTimeoutStage, - stage_test_config_class=OpTimeoutStageTestConfig, - extended_stage_instantiation_test_class=OpTimeoutStageInstantiationTests, -) - - -@pytest.mark.describe("OpTimeoutStage - .run_op() -- Called with operation eligible for timeout") -class TestOpTimeoutStageRunOpCalledWithOpThatCanTimeout( - OpTimeoutStageTestConfig, StageRunOpTestBase -): - @pytest.fixture(params=ops_that_time_out) - def op(self, mocker, request): - op_cls = request.param - op = op_cls(topic="some/topic", callback=mocker.MagicMock()) - return op - - @pytest.mark.it( - "Adds a timeout timer with the interval specified in the configuration to the operation, and starts it" - ) - def test_adds_timer(self, mocker, stage, op, mock_timer): - - stage.run_op(op) - - assert mock_timer.call_count == 1 - assert mock_timer.call_args == mocker.call(stage.timeout_intervals[type(op)], mocker.ANY) - assert op.timeout_timer is mock_timer.return_value - assert op.timeout_timer.start.call_count == 1 - assert op.timeout_timer.start.call_args == mocker.call() - - @pytest.mark.it("Sends the operation down the pipeline") - def test_sends_down(self, mocker, stage, op, mock_timer): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - 
assert stage.send_op_down.call_args == mocker.call(op) - assert op.timeout_timer is mock_timer.return_value - - -@pytest.mark.describe( - "OpTimeoutStage - .run_op() -- Called with arbitrary operation that is not eligible for timeout" -) -class TestOpTimeoutStageRunOpCalledWithOpThatDoesNotTimeout( - OpTimeoutStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def op(self, arbitrary_op): - return arbitrary_op - - @pytest.mark.it("Sends the operation down the pipeline without attaching a timeout timer") - def test_sends_down(self, mocker, stage, op, mock_timer): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - assert mock_timer.call_count == 0 - assert not hasattr(op, "timeout_timer") - - -@pytest.mark.describe( - "OpTimeoutStage - OCCURRENCE: Operation with a timeout timer times out before completion" -) -class TestOpTimeoutStageOpTimesOut(OpTimeoutStageTestConfig): - @pytest.fixture(params=ops_that_time_out) - def op(self, mocker, request): - op_cls = request.param - op = op_cls(topic="some/topic", callback=mocker.MagicMock()) - return op - - @pytest.mark.it("Completes the operation unsuccessfully, with a PipelineTimeoutError") - def test_pipeline_timeout(self, mocker, stage, op, mock_timer): - # Apply the timer - stage.run_op(op) - assert not op.completed - assert mock_timer.call_count == 1 - on_timer_complete = mock_timer.call_args[0][1] - - # Call timer complete callback (indicating timer completion) - on_timer_complete() - - # Op is now completed with error - assert op.completed - assert isinstance(op.error, pipeline_exceptions.OperationTimeout) - - -@pytest.mark.describe( - "OpTimeoutStage - OCCURRENCE: Operation with a timeout timer completes before timeout" -) -class TestOpTimeoutStageOpCompletesBeforeTimeout(OpTimeoutStageTestConfig): - @pytest.fixture(params=ops_that_time_out) - def op(self, mocker, request): - op_cls = request.param - op = op_cls(topic="some/topic", 
callback=mocker.MagicMock()) - return op - - @pytest.mark.it("Cancels and clears the operation's timeout timer") - def test_complete_before_timeout(self, mocker, stage, op, mock_timer): - # Apply the timer - stage.run_op(op) - assert not op.completed - assert mock_timer.call_count == 1 - mock_timer_inst = op.timeout_timer - assert mock_timer_inst is mock_timer.return_value - assert mock_timer_inst.cancel.call_count == 0 - - # Complete the operation - op.complete() - - # Timer is now cancelled and cleared - assert mock_timer_inst.cancel.call_count == 1 - assert mock_timer_inst.cancel.call_args == mocker.call() - assert op.timeout_timer is None - - -############### -# RETRY STAGE # -############### - -# Tuples of classname + args -retryable_ops = [ - (pipeline_ops_mqtt.MQTTSubscribeOperation, {"topic": "fake_topic", "callback": fake_callback}), - ( - pipeline_ops_mqtt.MQTTUnsubscribeOperation, - {"topic": "fake_topic", "callback": fake_callback}, - ), -] - -retryable_exceptions = [pipeline_exceptions.OperationTimeout] - - -class RetryStageTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_base.RetryStage - - @pytest.fixture - def init_kwargs(self, mocker): - return {} - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - mocker.spy(stage, "run_op") - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - -class RetryStageInstantiationTests(RetryStageTestConfig): - # TODO: this will no longer be necessary once these are implemented as part of a more robust retry policy - @pytest.mark.it( - "Sets default retry intervals to 20 seconds for MQTTSubscribeOperation and MQTTUnsubscribeOperation" - ) - def test_retry_intervals(self, init_kwargs): - stage = pipeline_stages_base.RetryStage(**init_kwargs) - assert 
stage.retry_intervals[pipeline_ops_mqtt.MQTTSubscribeOperation] == 20 - assert stage.retry_intervals[pipeline_ops_mqtt.MQTTUnsubscribeOperation] == 20 - - @pytest.mark.it("Initializes 'ops_waiting_to_retry' as an empty list") - def test_ops_waiting_to_retry(self, init_kwargs): - stage = pipeline_stages_base.RetryStage(**init_kwargs) - assert stage.ops_waiting_to_retry == [] - - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_base.RetryStage, - stage_test_config_class=RetryStageTestConfig, - extended_stage_instantiation_test_class=RetryStageInstantiationTests, -) - - -# NOTE: Although there is a branch in the implementation that distinguishes between -# retryable operations, and non-retryable operations, with retryable operations having -# a callback added, this is not captured in this test, as callback resolution is tested -# in a different unit. -@pytest.mark.describe("RetryStage - .run_op()") -class TestRetryStageRunOp(RetryStageTestConfig, StageRunOpTestBase): - ops = retryable_ops + [(ArbitraryOperation, {"callback": fake_callback})] - - @pytest.fixture(params=ops, ids=[x[0].__name__ for x in ops]) - def op(self, request, mocker): - op_cls = request.param[0] - init_kwargs = request.param[1] - return op_cls(**init_kwargs) - - @pytest.mark.it("Sends the operation down the pipeline") - def test_sends_op_down(self, mocker, stage, op): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - -@pytest.mark.describe( - "RetryStage - OCCURRENCE: Retryable operation completes unsuccessfully with a retryable error after call to .run_op()" -) -class TestRetryStageRetryableOperationCompletedWithRetryableError(RetryStageTestConfig): - @pytest.fixture(params=retryable_ops, ids=[x[0].__name__ for x in retryable_ops]) - def op(self, request, mocker): - op_cls = request.param[0] - init_kwargs = request.param[1] - return 
op_cls(**init_kwargs) - - @pytest.fixture(params=retryable_exceptions) - def error(self, request): - return request.param() - - @pytest.mark.it("Halts operation completion") - def test_halt(self, mocker, stage, op, error, mock_timer): - stage.run_op(op) - op.complete(error=error) - - assert not op.completed - - @pytest.mark.it( - "Adds a retry timer to the operation with the interval specified for the operation by the configuration, and starts it" - ) - def test_timer(self, mocker, stage, op, error, mock_timer): - stage.run_op(op) - op.complete(error=error) - - assert mock_timer.call_count == 1 - assert mock_timer.call_args == mocker.call(stage.retry_intervals[type(op)], mocker.ANY) - assert op.retry_timer is mock_timer.return_value - assert op.retry_timer.start.call_count == 1 - assert op.retry_timer.start.call_args == mocker.call() - - @pytest.mark.it( - "Adds the operation to the list of 'ops_waiting_to_retry' only for the duration of the timer" - ) - def test_adds_to_waiting_list_during_timer(self, mocker, stage, op, error, mock_timer): - stage.run_op(op) - - # The op is not listed as waiting for retry before completion - assert op not in stage.ops_waiting_to_retry - - # Completing the op starts the timer - op.complete(error=error) - assert mock_timer.call_count == 1 - timer_callback = mock_timer.call_args[0][1] - assert mock_timer.return_value.start.call_count == 1 - - # Once completed and the timer has been started, the op IS listed as waiting for retry - assert op in stage.ops_waiting_to_retry - - # Simulate timer completion - timer_callback() - - # Once the timer is completed, the op is no longer listed as waiting for retry - assert op not in stage.ops_waiting_to_retry - - @pytest.mark.it("Re-runs the operation after the retry timer expires") - def test_reruns(self, mocker, stage, op, error, mock_timer): - stage.run_op(op) - op.complete(error=error) - - assert stage.run_op.call_count == 1 - assert mock_timer.call_count == 1 - timer_callback = 
mock_timer.call_args[0][1] - - # Simulate timer completion - timer_callback() - - # run_op was called again - assert stage.run_op.call_count == 2 - - @pytest.mark.it("Cancels and clears the retry timer after the retry timer expires") - def test_clears_retry_timer(self, mocker, stage, op, error, mock_timer): - stage.run_op(op) - op.complete(error=error) - timer_callback = mock_timer.call_args[0][1] - - assert mock_timer.cancel.call_count == 0 - assert op.retry_timer is mock_timer.return_value - - # Simulate timer completion - timer_callback() - - assert mock_timer.return_value.cancel.call_count == 1 - assert mock_timer.return_value.cancel.call_args == mocker.call() - assert op.retry_timer is None - - @pytest.mark.it( - "Adds a new retry timer to the re-run operation, if it completes unsuccessfully again" - ) - def test_rerun_op_unsuccessful_again(self, mocker, stage, op, error, mock_timer): - stage.run_op(op) - assert stage.run_op.call_count == 1 - - # Complete with failure the first time - op.complete(error=error) - - assert mock_timer.call_count == 1 - assert op.retry_timer is mock_timer.return_value - timer_callback1 = mock_timer.call_args[0][1] - - # Trigger retry - timer_callback1() - - assert stage.run_op.call_count == 2 - assert stage.run_op.call_args == mocker.call(op) - assert op.retry_timer is None - - # Complete with failure the second time - op.complete(error=error) - - assert mock_timer.call_count == 2 - assert op.retry_timer is mock_timer.return_value - timer_callback2 = mock_timer.call_args[0][1] - - # Trigger retry again - timer_callback2() - - assert stage.run_op.call_count == 3 - assert stage.run_op.call_args == mocker.call(op) - assert op.retry_timer is None - - @pytest.mark.it("Supports multiple simultaneous operations retrying") - def test_multiple_retries(self, mocker, stage, mock_timer): - op1 = pipeline_ops_mqtt.MQTTSubscribeOperation( - topic="fake_topic_1", callback=mocker.MagicMock() - ) - op2 = pipeline_ops_mqtt.MQTTSubscribeOperation( - 
topic="fake_topic_2", callback=mocker.MagicMock() - ) - op3 = pipeline_ops_mqtt.MQTTUnsubscribeOperation( - topic="fake_topic_3", callback=mocker.MagicMock() - ) - - stage.run_op(op1) - stage.run_op(op2) - stage.run_op(op3) - assert stage.run_op.call_count == 3 - - assert not op1.completed - assert not op2.completed - assert not op3.completed - - op1.complete(error=pipeline_exceptions.OperationTimeout()) - op2.complete(error=pipeline_exceptions.OperationTimeout()) - op3.complete(error=pipeline_exceptions.OperationTimeout()) - - # Ops halted - assert not op1.completed - assert not op2.completed - assert not op3.completed - - # Timers set - assert mock_timer.call_count == 3 - assert op1.retry_timer is mock_timer.return_value - assert op2.retry_timer is mock_timer.return_value - assert op3.retry_timer is mock_timer.return_value - assert mock_timer.return_value.start.call_count == 3 - - # Operations awaiting retry - assert op1 in stage.ops_waiting_to_retry - assert op2 in stage.ops_waiting_to_retry - assert op3 in stage.ops_waiting_to_retry - - timer1_complete = mock_timer.call_args_list[0][0][1] - timer2_complete = mock_timer.call_args_list[1][0][1] - timer3_complete = mock_timer.call_args_list[2][0][1] - - # Trigger op1's timer to complete - timer1_complete() - - # Only op1 was re-run, and had it's timer removed - assert mock_timer.return_value.cancel.call_count == 1 - assert op1.retry_timer is None - assert op1 not in stage.ops_waiting_to_retry - assert op2.retry_timer is mock_timer.return_value - assert op2 in stage.ops_waiting_to_retry - assert op3.retry_timer is mock_timer.return_value - assert op3 in stage.ops_waiting_to_retry - assert stage.run_op.call_count == 4 - assert stage.run_op.call_args == mocker.call(op1) - - # Trigger op2's timer to complete - timer2_complete() - - # Only op2 was re-run and had it's timer removed - assert mock_timer.return_value.cancel.call_count == 2 - assert op2.retry_timer is None - assert op2 not in stage.ops_waiting_to_retry - 
assert op3.retry_timer is mock_timer.return_value - assert op3 in stage.ops_waiting_to_retry - assert stage.run_op.call_count == 5 - assert stage.run_op.call_args == mocker.call(op2) - - # Trigger op3's timer to complete - timer3_complete() - - # op3 has now also been re-run and had it's timer removed - assert op3.retry_timer is None - assert op3 not in stage.ops_waiting_to_retry - assert stage.run_op.call_count == 6 - assert stage.run_op.call_args == mocker.call(op3) - - -@pytest.mark.describe( - "RetryStage - OCCURRENCE: Retryable operation completes unsuccessfully with a non-retryable error after call to .run_op()" -) -class TestRetryStageRetryableOperationCompletedWithNonRetryableError(RetryStageTestConfig): - @pytest.fixture(params=retryable_ops, ids=[x[0].__name__ for x in retryable_ops]) - def op(self, request, mocker): - op_cls = request.param[0] - init_kwargs = request.param[1] - return op_cls(**init_kwargs) - - @pytest.fixture - def error(self, arbitrary_exception): - return arbitrary_exception - - @pytest.mark.it("Completes normally without retry") - def test_no_retry(self, mocker, stage, op, error, mock_timer): - stage.run_op(op) - op.complete(error=error) - - assert op.completed - assert op not in stage.ops_waiting_to_retry - assert mock_timer.call_count == 0 - - @pytest.mark.it("Cancels and clears the operation's retry timer, if one exists") - def test_cancels_existing_timer(self, mocker, stage, op, error, mock_timer): - # NOTE: This shouldn't happen naturally. We have to artificially create this circumstance - stage.run_op(op) - - # Artificially add a timer. 
Note that this is already mocked due to the 'mock_timer' fixture - op.retry_timer = threading.Timer(20, fake_callback) - assert op.retry_timer is mock_timer.return_value - - op.complete(error=error) - - assert op.completed - assert mock_timer.return_value.cancel.call_count == 1 - assert op.retry_timer is None - - -@pytest.mark.describe( - "RetryStage - OCCURRENCE: Retryable operation completes successfully after call to .run_op()" -) -class TestRetryStageRetryableOperationCompletedSuccessfully(RetryStageTestConfig): - @pytest.fixture(params=retryable_ops, ids=[x[0].__name__ for x in retryable_ops]) - def op(self, request, mocker): - op_cls = request.param[0] - init_kwargs = request.param[1] - return op_cls(**init_kwargs) - - @pytest.mark.it("Completes normally without retry") - def test_no_retry(self, mocker, stage, op, mock_timer): - stage.run_op(op) - op.complete() - - assert op.completed - assert op not in stage.ops_waiting_to_retry - assert mock_timer.call_count == 0 - - # NOTE: this isn't doing anything because arb ops don't trigger callback - @pytest.mark.it("Cancels and clears the operation's retry timer, if one exists") - def test_cancels_existing_timer(self, mocker, stage, op, mock_timer): - # NOTE: This shouldn't happen naturally. We have to artificially create this circumstance - stage.run_op(op) - - # Artificially add a timer. 
Note that this is already mocked due to the 'mock_timer' fixture - op.retry_timer = threading.Timer(20, fake_callback) - assert op.retry_timer is mock_timer.return_value - - op.complete() - - assert op.completed - assert mock_timer.return_value.cancel.call_count == 1 - assert op.retry_timer is None - - -@pytest.mark.describe( - "RetryStage - OCCURRENCE: Non-retryable operation completes after call to .run_op()" -) -class TestRetryStageNonretryableOperationCompleted(RetryStageTestConfig): - @pytest.fixture - def op(self, arbitrary_op): - return arbitrary_op - - @pytest.mark.it("Completes normally without retry, if completed successfully") - def test_successful_completion(self, mocker, stage, op, mock_timer): - stage.run_op(op) - op.complete() - - assert op.completed - assert op not in stage.ops_waiting_to_retry - assert mock_timer.call_count == 0 - - @pytest.mark.it( - "Completes normally without retry, if completed unsuccessfully with a non-retryable exception" - ) - def test_unsuccessful_non_retryable_err( - self, mocker, stage, op, arbitrary_exception, mock_timer - ): - stage.run_op(op) - op.complete(error=arbitrary_exception) - - assert op.completed - assert op not in stage.ops_waiting_to_retry - assert mock_timer.call_count == 0 - - @pytest.mark.it( - "Completes normally without retry, if completed unsuccessfully with a retryable exception" - ) - @pytest.mark.parametrize("exception", retryable_exceptions) - def test_unsuccessful_retryable_err(self, mocker, stage, op, exception, mock_timer): - stage.run_op(op) - op.complete(error=exception) - - assert op.completed - assert op not in stage.ops_waiting_to_retry - assert mock_timer.call_count == 0 - - -################### -# RECONNECT STAGE # -################### - - -class ConnectionStateStageTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_base.ConnectionStateStage - - @pytest.fixture - def init_kwargs(self, mocker): - return {} - - @pytest.fixture - def stage(self, mocker, 
cls_type, init_kwargs, nucleus): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.nucleus.pipeline_configuration.connection_retry_interval = 1234 - mocker.spy(stage, "run_op") - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - -class ConnectionStateStageInstantiationTests(ConnectionStateStageTestConfig): - @pytest.mark.it("Initializes the 'reconnect_timer' attribute as None") - def test_reconnect_timer(self, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - assert stage.reconnect_timer is None - - @pytest.mark.it("Initializes the 'waiting_ops' queue") - def test_waiting_connect_ops(self, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - assert isinstance(stage.waiting_ops, queue.Queue) - assert stage.waiting_ops.empty() - - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_base.ConnectionStateStage, - stage_test_config_class=ConnectionStateStageTestConfig, - extended_stage_instantiation_test_class=ConnectionStateStageInstantiationTests, -) - - -@pytest.mark.describe("ConnectionStateStage - .run_op() -- Called with ConnectOperation") -class TestConnectionStateStageRunOpWithConnectOperation( - ConnectionStateStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) - - @pytest.mark.it( - "Adds the operation to the `waiting_ops` queue and does nothing else if the pipeline connection is in an intermediate state" - ) - @pytest.mark.parametrize( - "state", - [ - ConnectionState.CONNECTING, - ConnectionState.DISCONNECTING, - ConnectionState.REAUTHORIZING, - ], - ) - def test_intermediate_state(self, stage, op, state): - stage.nucleus.connection_state = state - assert stage.waiting_ops.empty() - - stage.run_op(op) - - assert not stage.waiting_ops.empty() - assert 
stage.waiting_ops.qsize() == 1 - assert stage.waiting_ops.get() is op - assert stage.send_op_down.call_count == 0 - - @pytest.mark.it( - "Completes the operation without changing the state if the pipeline is already in a CONNECTED state" - ) - def test_connected_state_change(self, stage, op): - stage.nucleus.connection_state = ConnectionState.CONNECTED - assert not op.completed - - stage.run_op(op) - - assert op.completed - assert stage.nucleus.connection_state is ConnectionState.CONNECTED - assert stage.send_op_down.call_count == 0 - - @pytest.mark.it( - "Changes the state to CONNECTING and sends the operation down the pipeline if the pipeline is in a DISCONNECTED state" - ) - def test_disconnected_state_change(self, mocker, stage, op): - stage.nucleus.connection_state = ConnectionState.DISCONNECTED - stage.run_op(op) - assert stage.nucleus.connection_state is ConnectionState.CONNECTING - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - @pytest.mark.it( - "Sets the state to DISCONNECTED if the operation sent down the pipeline completes with error" - ) - def test_op_completes_error(self, stage, op, arbitrary_exception): - stage.nucleus.connection_state = ConnectionState.DISCONNECTED - stage.run_op(op) - assert stage.nucleus.connection_state is ConnectionState.CONNECTING - - op.complete(arbitrary_exception) - - assert stage.nucleus.connection_state is ConnectionState.DISCONNECTED - - @pytest.mark.it( - "Does not change the state if the operation sent down the pipeline completes successfully" - ) - def test_op_completes_success(self, stage, op): - stage.nucleus.connection_state = ConnectionState.DISCONNECTED - stage.run_op(op) - assert stage.nucleus.connection_state is ConnectionState.CONNECTING - op.complete() - - # NOTE: This is a very weird test in that this would never happen like this "in the wild" - # In a real scenario, prior to the operation completing, an event would fire, and that - # event would cause a 
state change to the desired state, so the state would actually not - # still be this in this "modified state" that was a result of the operation running through - # the pipeline. However, that's kind of the important thing we need to test - that the - # operation completing DOES NOT change the state, because that's not it's job. So in order - # to show this, we will NOT emulate the state change that occurs from the event. Just - # remember that in practice, the state would not actually still be the modified state, but - # instead the desired goal state - assert stage.nucleus.connection_state is ConnectionState.CONNECTING - - @pytest.mark.it( - "Re-runs all of the ops in the `waiting_ops` queue (if any) upon completion of the op after it is sent down" - ) - @pytest.mark.parametrize( - "queued_ops", - [ - pytest.param( - [pipeline_ops_base.DisconnectOperation(callback=None)], id="Single op waiting" - ), - pytest.param( - [ - pipeline_ops_base.ReauthorizeConnectionOperation(callback=None), - pipeline_ops_base.ConnectOperation(callback=None), - ], - id="Multiple ops waiting", - ), - ], - ) - @pytest.mark.parametrize( - "success", - [ - pytest.param(True, id="Operation completes with success"), - pytest.param(False, id="Operation completes with error"), - ], - ) - def test_op_completes_causes_waiting_rerun( - self, mocker, stage, op, queued_ops, success, arbitrary_exception - ): - stage.nucleus.connection_state = ConnectionState.DISCONNECTED - stage.run_op(op) - assert stage.nucleus.connection_state is ConnectionState.CONNECTING - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - # Before completion, more ops come down and queue up - for queued_op in queued_ops: - stage.run_op(queued_op) - assert stage.waiting_ops.qsize() == len(queued_ops) - - # Now mock out run_op so we can see if it gets called - stage.run_op = mocker.MagicMock() - - # As mentioned above, before operations complete successfully, an event will be 
fired, - # and this event will trigger state change. We need to emulate this one here so that - # the waiting ops will not end up requeued. - if success: - stage.nucleus.connection_state = ConnectionState.CONNECTED - op.complete() - else: - op.complete(arbitrary_exception) - - # All items were removed from the waiting queue and run on the stage - assert stage.waiting_ops.qsize() == 0 - assert stage.run_op.call_count == len(queued_ops) - for i in range(len(queued_ops)): - assert stage.run_op.call_args_list[i] == mocker.call(queued_ops[i]) - - -@pytest.mark.describe("ConnectionStateStage - .run_op() -- Called with DisconnectOperation") -class TestConnectionStateStageRunOpWithDisconnectOperation( - ConnectionStateStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) - - @pytest.mark.it( - "Adds the operation to the `waiting_ops` queue and does nothing else if the pipeline connection is in an intermediate state" - ) - @pytest.mark.parametrize( - "state", - [ - ConnectionState.CONNECTING, - ConnectionState.DISCONNECTING, - ConnectionState.REAUTHORIZING, - ], - ) - def test_intermediate_state(self, stage, op, state): - stage.nucleus.connection_state = state - assert stage.waiting_ops.empty() - - stage.run_op(op) - - assert not stage.waiting_ops.empty() - assert stage.waiting_ops.qsize() == 1 - assert stage.waiting_ops.get() is op - assert stage.send_op_down.call_count == 0 - - @pytest.mark.it( - "Clears any reconnection timer that may exist if the pipeline connection is in a stable state" - ) - @pytest.mark.parametrize( - "state", - [ - ConnectionState.CONNECTED, - ConnectionState.DISCONNECTED, - ], - ) - def test_clears_reconnect_timer(self, mocker, stage, op, state): - stage.nucleus.connection_state = state - timer_mock = mocker.MagicMock() - stage.reconnect_timer = timer_mock - - stage.run_op(op) - - assert stage.reconnect_timer is None - assert 
timer_mock.cancel.call_count == 1 - - @pytest.mark.it( - "Completes the operation without changing the state if the pipeline is already in a DISCONNECTED state" - ) - def test_connected_state_change(self, stage, op): - assert not op.completed - stage.nucleus.connection_state = ConnectionState.DISCONNECTED - - stage.run_op(op) - - assert op.completed - assert stage.nucleus.connection_state is ConnectionState.DISCONNECTED - assert stage.send_op_down.call_count == 0 - - @pytest.mark.it( - "Changes the state to DISCONNECTING and sends the operation down the pipeline if the pipeline is in a CONNECTED state" - ) - def test_disconnected_state_change(self, mocker, stage, op): - stage.nucleus.connection_state = ConnectionState.CONNECTED - stage.run_op(op) - assert stage.nucleus.connection_state is ConnectionState.DISCONNECTING - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - @pytest.mark.it( - "Sets the state to DISCONNECTED if the operation sent down the pipeline completes with error" - ) - def test_op_completes_error(self, stage, op, arbitrary_exception): - stage.nucleus.connection_state = ConnectionState.CONNECTED - stage.run_op(op) - assert stage.nucleus.connection_state is ConnectionState.DISCONNECTING - - op.complete(arbitrary_exception) - - assert stage.nucleus.connection_state is ConnectionState.DISCONNECTED - - @pytest.mark.it( - "Does not change the state if the operation sent down the pipeline completes successfully" - ) - def test_op_completes_success(self, stage, op): - stage.nucleus.connection_state = ConnectionState.CONNECTED - stage.run_op(op) - assert stage.nucleus.connection_state is ConnectionState.DISCONNECTING - - op.complete() - - # NOTE: This is a very weird test in that this would never happen like this "in the wild" - # In a real scenario, prior to the operation completing, an event would fire, and that - # event would cause a state change to the desired state, so the state would actually not - 
# still be this in this "modified state" that was a result of the operation running through - # the pipeline. However, that's kind of the important thing we need to test - that the - # operation completing DOES NOT change the state, because that's not it's job. So in order - # to show this, we will NOT emulate the state change that occurs from the event. Just - # remember that in practice, the state would not actually still be the modified state, but - # instead the desired goal state - assert stage.nucleus.connection_state is ConnectionState.DISCONNECTING - - @pytest.mark.it( - "Re-runs all the waiting ops in the `waiting_ops` queue (if any) upon completion of the op after it is sent down" - ) - @pytest.mark.parametrize( - "queued_ops", - [ - pytest.param( - [pipeline_ops_base.DisconnectOperation(callback=None)], id="Single op waiting" - ), - pytest.param( - [ - pipeline_ops_base.ReauthorizeConnectionOperation(callback=None), - pipeline_ops_base.ConnectOperation(callback=None), - ], - id="Multiple ops waiting", - ), - ], - ) - @pytest.mark.parametrize( - "success", - [ - pytest.param(True, id="Operation completes with success"), - pytest.param(False, id="Operation completes with error"), - ], - ) - def test_op_completes_causes_waiting_rerun( - self, mocker, stage, op, queued_ops, success, arbitrary_exception - ): - stage.nucleus.connection_state = ConnectionState.CONNECTED - stage.run_op(op) - assert stage.nucleus.connection_state is ConnectionState.DISCONNECTING - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - # Before completion, more ops come down and queue up - for queued_op in queued_ops: - stage.run_op(queued_op) - assert stage.waiting_ops.qsize() == len(queued_ops) - - # Now mock out run_op so we can see if it gets called - stage.run_op = mocker.MagicMock() - - # As mentioned above, before operations complete successfully, an event will be fired, - # and this event will trigger state change. 
We need to emulate this one here so that - # the waiting ops will not end up requeued. - if success: - stage.nucleus.connection_state = ConnectionState.DISCONNECTED - op.complete() - else: - op.complete(arbitrary_exception) - - # All items were removed from the waiting queue and run on the stage - assert stage.waiting_ops.qsize() == 0 - assert stage.run_op.call_count == len(queued_ops) - for i in range(len(queued_ops)): - assert stage.run_op.call_args_list[i] == mocker.call(queued_ops[i]) - - -@pytest.mark.describe( - "ConnectionStateStage - .run_op() -- Called with ReauthorizeConnectionOperation" -) -class TestConnectionStateStageRunOpWithReauthorizeConnectionOperation( - ConnectionStateStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.ReauthorizeConnectionOperation(callback=mocker.MagicMock()) - - @pytest.mark.it( - "Adds the operation to the `waiting_ops` queue and does nothing else if the pipeline is in an intermediate state" - ) - @pytest.mark.parametrize( - "state", - [ - ConnectionState.CONNECTING, - ConnectionState.DISCONNECTING, - ConnectionState.REAUTHORIZING, - ], - ) - def test_intermediate_state(self, stage, op, state): - stage.nucleus.connection_state = state - assert stage.waiting_ops.empty() - - stage.run_op(op) - - assert not stage.waiting_ops.empty() - assert stage.waiting_ops.qsize() == 1 - assert stage.waiting_ops.get() is op - assert stage.send_op_down.call_count == 0 - - @pytest.mark.it( - "Changes the state to REAUTHORIZING and sends the operation down the pipeline if the pipeline is in a CONNECTED state" - ) - def test_connected_state_change(self, mocker, stage, op): - stage.nucleus.connection_state = ConnectionState.CONNECTED - stage.run_op(op) - assert stage.nucleus.connection_state is ConnectionState.REAUTHORIZING - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - @pytest.mark.it( - "Changes the state to REAUTHORIZING and sends 
the operation down the pipeline if the pipeline is in a DISCONNECTED state" - ) - def test_disconnected_state_change(self, mocker, stage, op): - stage.nucleus.connection_state = ConnectionState.DISCONNECTED - stage.run_op(op) - assert stage.nucleus.connection_state is ConnectionState.REAUTHORIZING - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - @pytest.mark.it( - "Sets the state to DISCONNECTED if the operation sent down the pipeline completes with error" - ) - @pytest.mark.parametrize( - "original_state, modified_state", - [ - pytest.param( - ConnectionState.CONNECTED, - ConnectionState.REAUTHORIZING, - id="CONNECTED->REAUTHORIZING", - ), - pytest.param( - ConnectionState.DISCONNECTED, - ConnectionState.REAUTHORIZING, - id="DISCONNECTED->REAUTHORIZING", - ), - ], - ) - def test_op_completes_error( - self, stage, op, original_state, modified_state, arbitrary_exception - ): - stage.nucleus.connection_state = original_state - stage.run_op(op) - assert stage.nucleus.connection_state == modified_state - - op.complete(arbitrary_exception) - - assert stage.nucleus.connection_state is ConnectionState.DISCONNECTED - - @pytest.mark.it( - "Does not change the state if the operation sent down the pipeline completes successfully" - ) - @pytest.mark.parametrize( - "original_state, modified_state", - [ - pytest.param( - ConnectionState.CONNECTED, - ConnectionState.REAUTHORIZING, - id="CONNECTED->REAUTHORIZING", - ), - pytest.param( - ConnectionState.DISCONNECTED, - ConnectionState.REAUTHORIZING, - id="DISCONNECTED->REAUTHORIZING", - ), - ], - ) - def test_op_completes_success(self, stage, op, original_state, modified_state): - stage.nucleus.connection_state = original_state - stage.run_op(op) - assert stage.nucleus.connection_state == modified_state - - op.complete() - - # NOTE: This is a very weird test in that this would never happen like this "in the wild" - # In a real scenario, prior to the operation completing, an 
event would fire, and that - # event would cause a state change to the desired state, so the state would actually not - # still be this in this "modified state" that was a result of the operation running through - # the pipeline. However, that's kind of the important thing we need to test - that the - # operation completing DOES NOT change the state, because that's not it's job. So in order - # to show this, we will NOT emulate the state change that occurs from the event. Just - # remember that in practice, the state would not actually still be the modified state, but - # instead the desired goal state - assert stage.nucleus.connection_state == modified_state - - @pytest.mark.it( - "Re-runs all of the ops in the `waiting_ops` queue (if any) upon completion of the op after it is sent down" - ) - @pytest.mark.parametrize( - "queued_ops", - [ - pytest.param( - [pipeline_ops_base.DisconnectOperation(callback=None)], id="Single op waiting" - ), - pytest.param( - [ - pipeline_ops_base.ReauthorizeConnectionOperation(callback=None), - pipeline_ops_base.ConnectOperation(callback=None), - ], - id="Multiple ops waiting", - ), - ], - ) - @pytest.mark.parametrize( - "success", - [ - pytest.param(True, id="Operation completes with success"), - pytest.param(False, id="Operation completes with error"), - ], - ) - def test_op_completes_causes_waiting_rerun( - self, mocker, stage, op, queued_ops, success, arbitrary_exception - ): - stage.nucleus.connection_state = ConnectionState.CONNECTED - stage.run_op(op) - assert stage.nucleus.connection_state is ConnectionState.REAUTHORIZING - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - # Before completion, more ops come down and queue up - for queued_op in queued_ops: - stage.run_op(queued_op) - assert stage.waiting_ops.qsize() == len(queued_ops) - - # Now mock out run_op so we can see if it gets called - stage.run_op = mocker.MagicMock() - - # As mentioned above, before operations 
complete successfully, an event will be fired, - # and this event will trigger state change. We need to emulate this one here so that - # the waiting ops will not end up requeued. - if success: - stage.nucleus.connection_state = ConnectionState.CONNECTED - op.complete() - else: - op.complete(arbitrary_exception) - - # All items were removed from the waiting queue and run on the stage - assert stage.waiting_ops.qsize() == 0 - assert stage.run_op.call_count == len(queued_ops) - for i in range(len(queued_ops)): - assert stage.run_op.call_args_list[i] == mocker.call(queued_ops[i]) - - -@pytest.mark.describe("ConnectionStateStage - .run_op() -- Called with ShutdownPipelineOperation") -class TestConnectionStateStageRunOpWithShutdownPipelineOperation( - ConnectionStateStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.ShutdownPipelineOperation(callback=mocker.MagicMock()) - - @pytest.mark.it( - "Adds the operation to the `waiting_ops` queue and does nothing else if the pipeline connection is in an intermediate state" - ) - @pytest.mark.parametrize( - "state", - [ - ConnectionState.CONNECTING, - ConnectionState.DISCONNECTING, - ConnectionState.REAUTHORIZING, - ], - ) - def test_intermediate_state(self, stage, op, state): - stage.nucleus.connection_state = state - assert stage.waiting_ops.empty() - - stage.run_op(op) - - assert not stage.waiting_ops.empty() - assert stage.waiting_ops.qsize() == 1 - assert stage.waiting_ops.get() is op - assert stage.send_op_down.call_count == 0 - - @pytest.mark.it( - "Clears any reconnection timer that may exist if the pipeline connection is in a stable state" - ) - @pytest.mark.parametrize( - "state", - [ - ConnectionState.CONNECTED, - ConnectionState.DISCONNECTED, - ], - ) - def test_timer_clear(self, mocker, op, stage, state): - stage.nucleus.connection_state = state - timer_mock = mocker.MagicMock() - stage.reconnect_timer = timer_mock - - stage.run_op(op) - - assert 
timer_mock.cancel.call_count == 1 - assert stage.reconnect_timer is None - - @pytest.mark.it( - "Cancels any operations in the `waiting_ops` queue if the pipeline connection is in a stable state" - ) - @pytest.mark.parametrize( - "state", - [ - ConnectionState.CONNECTED, - ConnectionState.DISCONNECTED, - ], - ) - def test_waiting_ops_cancellation(self, mocker, op, stage, state): - stage.nucleus.connection_state = state - waiting_op1 = pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) - waiting_op2 = pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) - waiting_op3 = pipeline_ops_base.ReauthorizeConnectionOperation(callback=mocker.MagicMock()) - stage.waiting_ops.put_nowait(waiting_op1) - stage.waiting_ops.put_nowait(waiting_op2) - stage.waiting_ops.put_nowait(waiting_op3) - - stage.run_op(op) - - assert stage.waiting_ops.empty() - assert waiting_op1.completed - assert isinstance(waiting_op1.error, pipeline_exceptions.OperationCancelled) - assert waiting_op2.completed - assert isinstance(waiting_op2.error, pipeline_exceptions.OperationCancelled) - assert waiting_op3.completed - assert isinstance(waiting_op3.error, pipeline_exceptions.OperationCancelled) - - @pytest.mark.it( - "Sends the operation down the pipeline without changing the state if the pipeline connection is in a stable state" - ) - @pytest.mark.parametrize( - "state", - [ - ConnectionState.CONNECTED, - ConnectionState.DISCONNECTED, - ], - ) - def test_sends_op_down(self, mocker, op, stage, state): - stage.nucleus.connection_state = state - - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - assert stage.nucleus.connection_state is state - - -@pytest.mark.describe("ConnectionStateStage - .run_op() -- Called with arbitrary other operation") -class TestConnectionStateStageRunOpWithArbitraryOperation( - ConnectionStateStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def op(self, arbitrary_op): - 
return arbitrary_op - - @pytest.mark.it( - "Sends the operation down the pipeline without changing the state if the pipeline is in a stable state" - ) - @pytest.mark.parametrize( - "state", - [ - ConnectionState.CONNECTED, - ConnectionState.DISCONNECTED, - ], - ) - def test_stable_state(self, mocker, op, stage, state): - stage.nucleus.connection_state = state - - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - assert stage.nucleus.connection_state is state - - @pytest.mark.it( - "Adds the operation to the `waiting_ops` queue and does nothing else if the pipeline is in an intermediate state" - ) - @pytest.mark.parametrize( - "state", - [ConnectionState.CONNECTING, ConnectionState.DISCONNECTING, ConnectionState.REAUTHORIZING], - ) - def test_intermediate_state(self, op, stage, state): - stage.nucleus.connection_state = state - assert stage.waiting_ops.empty() - - stage.run_op(op) - - assert not stage.waiting_ops.empty() - assert stage.waiting_ops.qsize() == 1 - assert stage.waiting_ops.get() is op - assert stage.send_op_down.call_count == 0 - - -@pytest.mark.describe( - "ConnectionStateStage - .handle_pipeline_event() -- Called with ConnectedEvent" -) -class TestConnectionStateStageHandlePipelineEventCalledWithConnectedEvent( - ConnectionStateStageTestConfig, StageHandlePipelineEventTestBase -): - @pytest.fixture - def event(self): - return pipeline_events_base.ConnectedEvent() - - @pytest.mark.it("Clears any reconnect timer that may exist") - @pytest.mark.parametrize( - "state", - [ - # Valid states - ConnectionState.CONNECTING, - ConnectionState.REAUTHORIZING, - # Invalid states (still test tho) - ConnectionState.DISCONNECTING, - ConnectionState.CONNECTED, - ConnectionState.DISCONNECTED, - ], - ) - def test_clears_reconnect_timer(self, mocker, stage, event, state): - stage.nucleus.connection_state = state - mock_timer = mocker.MagicMock() - stage.reconnect_timer = mock_timer - - 
stage.handle_pipeline_event(event) - - assert stage.reconnect_timer is None - assert mock_timer.cancel.call_count == 1 - - @pytest.mark.it( - "Changes the state to CONNECTED and sends the event up the pipeline if in a CONNECTING state" - ) - def test_connecting_state(self, mocker, stage, event): - stage.nucleus.connection_state = ConnectionState.CONNECTING - - stage.handle_pipeline_event(event) - - assert stage.nucleus.connection_state is ConnectionState.CONNECTED - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args == mocker.call(event) - - @pytest.mark.it( - "Changes the state to CONNECTED and sends the event up the pipeline if in a REAUTHORIZING state" - ) - def test_reauthorizing_state(self, mocker, stage, event): - stage.nucleus.connection_state = ConnectionState.REAUTHORIZING - - stage.handle_pipeline_event(event) - - assert stage.nucleus.connection_state is ConnectionState.CONNECTED - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args == mocker.call(event) - - @pytest.mark.it( - "Changes the state to CONNECTED and sends the event up the pipeline if in an invalid state" - ) - @pytest.mark.parametrize( - "state", - [ - ConnectionState.DISCONNECTING, - ConnectionState.CONNECTED, - ConnectionState.DISCONNECTED, - ], - ) - def test_invalid_states(self, mocker, stage, event, state): - # NOTE: This should never happen in practice - stage.nucleus.connection_state = state - - stage.handle_pipeline_event(event) - - assert stage.nucleus.connection_state is ConnectionState.CONNECTED - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args == mocker.call(event) - - -@pytest.mark.describe( - "ConnectionStateStage - .handle_pipeline_event() -- Called with DisconnectedEvent" -) -class TestConnectionStateStageHandlePipelineEventCalledWithDisconnectedEvent( - ConnectionStateStageTestConfig, StageHandlePipelineEventTestBase -): - @pytest.fixture - def event(self): - return 
pipeline_events_base.DisconnectedEvent() - - @pytest.mark.it( - "Changes the state to DISCONNECTED and sends the event up the pipeline if in a CONNECTED state (i.e. Unexpected Disconnect)" - ) - def test_connected_state(self, mocker, stage, event, mock_timer): - # mock_timer is required here, even though it's unused so that we don't set a real timer - stage.nucleus.connection_state = ConnectionState.CONNECTED - - stage.handle_pipeline_event(event) - - assert stage.nucleus.connection_state is ConnectionState.DISCONNECTED - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args == mocker.call(event) - - @pytest.mark.it( - "Changes the state to DISCONNECTED and sends the event up the pipeline if in a DISCONNECTING state (i.e. Expected Disconnect - Disconnection process)" - ) - def test_disconnecting_state(self, mocker, stage, event): - stage.nucleus.connection_state = ConnectionState.DISCONNECTING - - stage.handle_pipeline_event(event) - - assert stage.nucleus.connection_state is ConnectionState.DISCONNECTED - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args == mocker.call(event) - - @pytest.mark.it( - "Does NOT change the state, but sends the event up the pipeline if in a REAUTHORIZING state (i.e. 
Expected Disconnect - Reauthorization process)" - ) - def test_reauthorizing_state(self, mocker, stage, event): - stage.nucleus.connection_state = ConnectionState.REAUTHORIZING - - stage.handle_pipeline_event(event) - - assert stage.nucleus.connection_state is ConnectionState.REAUTHORIZING - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args == mocker.call(event) - - @pytest.mark.it( - "Changes the state to DISCONNECTED and sends the event up the pipeline if in an invalid state" - ) - @pytest.mark.parametrize("state", [ConnectionState.DISCONNECTED, ConnectionState.CONNECTING]) - def test_invalid_states(self, mocker, stage, event, state): - # NOTE: This should never happen in practice - stage.nucleus.connection_state = state - - stage.handle_pipeline_event(event) - - assert stage.nucleus.connection_state is ConnectionState.DISCONNECTED - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args == mocker.call(event) - - @pytest.mark.it( - "Starts an immediate reconnect timer following an Unexpected Disconnect if Connection Retry is enabled" - ) - def test_reconnect_timer_created(self, mocker, stage, event, mock_timer): - stage.nucleus.pipeline_configuration.connection_retry = True - stage.nucleus.connection_state = ConnectionState.CONNECTED - assert stage.reconnect_timer is None - - stage.handle_pipeline_event(event) - - assert stage.reconnect_timer is mock_timer.return_value - assert mock_timer.call_count == 1 - assert mock_timer.call_args == mocker.call(0.01, mocker.ANY) - assert mock_timer.return_value.start.call_count == 1 - - @pytest.mark.it("Does NOT start a reconnect timer under any other conditions") - @pytest.mark.parametrize( - "state, retry_enabled", - [ - pytest.param( - ConnectionState.CONNECTED, - False, - id="Unexpected Disconnect - Connection Retry Disabled", - ), - pytest.param( - ConnectionState.DISCONNECTING, - True, - id="Expected Disconnect (Disconnection Process) - Connection Retry 
Enabled", - ), - pytest.param( - ConnectionState.DISCONNECTING, - False, - id="Expected Disconnect (Disconnection Process) - Connection Retry Disabled", - ), - pytest.param( - ConnectionState.REAUTHORIZING, - True, - id="Expected Disconnect (Reauthorization Process) - Connection Retry Enabled", - ), - pytest.param( - ConnectionState.REAUTHORIZING, - False, - id="Expected Disconnect (Reauthorization Process) - Connection Retry Disabled", - ), - pytest.param( - ConnectionState.DISCONNECTED, - True, - id="Unexpected Disconnect (Invalid State: DISCONNECTED) - Connection Retry Enabled", - ), - pytest.param( - ConnectionState.DISCONNECTED, - False, - id="Unexpected Disconnect (Invalid State: DISCONNECTED) - Connection Retry Disabled", - ), - pytest.param( - ConnectionState.CONNECTING, - True, - id="Unexpected Disconnect (Invalid State: CONNECTING) - Connection Retry Enabled", - ), - pytest.param( - ConnectionState.CONNECTING, - False, - id="Unexpected Disconnect (Invalid State: CONNECTING) - Connection Retry Disabled", - ), - ], - ) - def test_no_reconnect_timer_creation(self, stage, event, state, retry_enabled, mock_timer): - stage.nucleus.pipeline_configuration.connection_retry = retry_enabled - stage.nucleus.connection_state = state - assert stage.reconnect_timer is None - - stage.handle_pipeline_event(event) - - assert stage.reconnect_timer is None - assert mock_timer.call_count == 0 - - -@pytest.mark.describe( - "ConnectionStateStage - .handle_pipeline_event() -- Called with arbitrary other event" -) -class TestConnectionStateStageHandlePipelineEventCalledWithArbitraryEvent( - ConnectionStateStageTestConfig, StageHandlePipelineEventTestBase -): - @pytest.fixture - def event(self, arbitrary_event): - return arbitrary_event - - @pytest.mark.it( - "Sends the event up the pipeline without changing the state or starting a reconnect timer" - ) - @pytest.mark.parametrize( - "state", - [ - ConnectionState.CONNECTING, - ConnectionState.CONNECTED, - 
ConnectionState.DISCONNECTING, - ConnectionState.DISCONNECTED, - ConnectionState.REAUTHORIZING, - ], - ) - def test_sends_event_up(self, mocker, stage, event, state): - stage.nucleus.connection_state = state - assert stage.reconnect_timer is None - - stage.handle_pipeline_event(event) - - assert stage.nucleus.connection_state is state - assert stage.reconnect_timer is None - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args == mocker.call(event) - - -@pytest.mark.describe("ConnectionStateStage - OCCURRENCE: Reconnect Timer Expires") -class TestConnectionStateStageOCCURRENCEReconnectTimerExpires(ConnectionStateStageTestConfig): - @pytest.fixture( - params=[ - "Timer created by unexpected disconnect", - "Timer created by reconnect punting due to in-progress op", - "Timer created by failed reconnection attempt", - ] - ) - def trigger_stage_retry_timer_completion(self, request, stage, mock_timer): - """This fixture is parametrized to get the retry timer completion trigger for every - possible way a reconnect timer could have been made. 
This may seem redundant given that - in the implementation it's pretty clear they all work the same, but ensuring that is true - is the point of parametrizing the fixture""" - - # The stage must be connected in order to set a reconnect timer - stage.nucleus.connection_state = ConnectionState.CONNECTED - # Send a DisconnectedEvent to the stage in order to set up the timer - stage.handle_pipeline_event(pipeline_events_base.DisconnectedEvent()) - - if request.param == "Timer created by unexpected disconnect": - # Get timer completion callback - assert mock_timer.call_count == 1 - timer_callback = mock_timer.call_args[0][1] - assert stage.nucleus.connection_state is ConnectionState.DISCONNECTED - - elif request.param == "Timer created by reconnect punting due to in-progress op": - # Get first timer completion callback - assert mock_timer.call_count == 1 - timer_callback = mock_timer.call_args[0][1] - assert stage.nucleus.connection_state is ConnectionState.DISCONNECTED - - # Make the stage have an in-progress op going - stage.nucleus.connection_state = ConnectionState.REAUTHORIZING - - # Invoke the timer completion (which will cause another timer to create) - timer_callback() - - # Get second timer completion callback - assert mock_timer.call_count == 2 - timer_callback = mock_timer.call_args[0][1] - elif request.param == "Timer created by failed reconnection attempt": - # Get first timer completion callback - assert mock_timer.call_count == 1 - timer_callback = mock_timer.call_args[0][1] - assert stage.nucleus.connection_state is ConnectionState.DISCONNECTED - - # Complete the callback, triggering a reconnection - timer_callback() - - # Get the op that was sent down - assert stage.send_op_down.call_count == 1 - op = stage.send_op_down.call_args[0][0] - - # Fail the op with a transient error - op.complete(error=transport_exceptions.ConnectionFailedError()) - - # Get second timer completion callback - assert mock_timer.call_count == 2 - timer_callback = 
mock_timer.call_args[0][1] - - # Reset mock so none of this stuff counts in the test - mock_timer.reset_mock() - stage.send_op_down.reset_mock() - stage.send_event_up.reset_mock() - return timer_callback - - @pytest.mark.it( - "Sends a new ConnectOperation down the pipeline, changes the state to CONNECTING and clears the reconnect timer if timer expires and the state is DISCONNECTED (i.e. do a reconnect)" - ) - def test_disconnected_state(self, stage, trigger_stage_retry_timer_completion): - stage.nucleus.connection_state = ConnectionState.DISCONNECTED - assert stage.reconnect_timer is not None - - trigger_stage_retry_timer_completion() - - assert stage.nucleus.connection_state is ConnectionState.CONNECTING - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_base.ConnectOperation) - assert stage.reconnect_timer is None - - @pytest.mark.it( - "Start a new reconnect timer for the interval specified by the pipeline config, but do not change the state or send anything down the pipeline, if the timer expires and the state is an intermediate state (i.e. punt until later)" - ) - @pytest.mark.parametrize( - "state", - [ - ConnectionState.CONNECTING, - ConnectionState.DISCONNECTING, - ConnectionState.REAUTHORIZING, - ], - ) - def test_intermediate_state( - self, mocker, stage, trigger_stage_retry_timer_completion, state, mock_timer - ): - stage.nucleus.connection_state = state - # Have to replace the timer with a manual mock here because mocked classes always - # return the same object, but we want to show that the object is replaced, so need - # to make something different. 
- stage.reconnect_timer = mocker.MagicMock() - old_reconnect_timer = stage.reconnect_timer - - trigger_stage_retry_timer_completion() - - assert stage.nucleus.connection_state is state - assert stage.send_op_down.call_count == 0 - assert mock_timer.call_count == 1 - assert mock_timer.call_args == mocker.call( - stage.nucleus.pipeline_configuration.connection_retry_interval, mocker.ANY - ) - assert stage.reconnect_timer is mock_timer.return_value - assert stage.reconnect_timer is not old_reconnect_timer - - @pytest.mark.it( - "Does not change the state or send anything down the pipeline or start any timers if in an invalid state" - ) - @pytest.mark.parametrize("state", [ConnectionState.CONNECTED]) - def test_invalid_states(self, stage, trigger_stage_retry_timer_completion, state, mock_timer): - # This should never happen in practice - stage.nucleus.connection_state = state - - trigger_stage_retry_timer_completion() - - assert stage.nucleus.connection_state is state - assert stage.send_op_down.call_count == 0 - assert mock_timer.call_count == 0 - assert stage.reconnect_timer is None - - -@pytest.mark.describe("ConnectionStateStage - OCCURRENCE: Reconnection Completes") -class TestConnectionStateStageOCCURRENCEReconnectionCompletes(ConnectionStateStageTestConfig): - @pytest.fixture( - params=[ - "Reconnect after unexpected disconnect", - "Reconnect after punted reconnection", - "Reconnect after failed reconnection", - ] - ) - def reconnect_op(self, request, stage, mock_timer): - """This fixture is parametrized to cover all possible sources of a reconnection to show - that reconnections behave the same, no matter how they are generated. 
- """ - # The stage must be connected and then lose connection in order to set - # the initial reconnect timer - assert mock_timer.call_count == 0 - stage.nucleus.connection_state = ConnectionState.CONNECTED - stage.handle_pipeline_event(pipeline_events_base.DisconnectedEvent()) - if request.param == "Reconnect after unexpected disconnect": - # Invoke the callback passed to the reconnect timer to spawn op - assert mock_timer.call_count == 1 - assert stage.send_op_down.call_count == 0 - timer_callback = mock_timer.call_args[0][1] - timer_callback() - # Get the reconnect op - assert stage.send_op_down.call_count == 1 - reconnect_op = stage.send_op_down.call_args[0][0] - elif request.param == "Reconnect after punted reconnection": - # Change the state to an in-progress one so that reconnect will punt - stage.nucleus.connection_state = ConnectionState.CONNECTING - # Invoke the callback passed to the reconnect timer - assert mock_timer.call_count == 1 - assert stage.send_op_down.call_count == 0 - timer_callback = mock_timer.call_args[0][1] - timer_callback() - # Reconnection punted (set a new timer) - assert stage.send_op_down.call_count == 0 - assert mock_timer.call_count == 2 - # Change state so next reconnect will not punt - stage.nucleus.connection_state = ConnectionState.DISCONNECTED - # Invoke the callback passed to the new reconnect timer - timer_callback = mock_timer.call_args[0][1] - timer_callback() - # Get the reconnect op - assert stage.send_op_down.call_count == 1 - reconnect_op = stage.send_op_down.call_args[0][0] - elif request.param == "Reconnect after failed reconnection": - # Invoke the callback passed to the reconnect timer to spawn op - assert mock_timer.call_count == 1 - assert stage.send_op_down.call_count == 0 - timer_callback = mock_timer.call_args[0][1] - timer_callback() - # Fail the resulting reconnect op with a transient error - assert stage.send_op_down.call_count == 1 - assert stage.report_background_exception.call_count == 0 - reconnect_op 
= stage.send_op_down.call_args[0][0] - reconnect_op.complete(error=transport_exceptions.ConnectionFailedError()) - # New reconnect timer set - assert mock_timer.call_count == 2 - assert stage.send_op_down.call_count == 1 - assert stage.report_background_exception.call_count == 1 - # Invoke the callback passed to the new reconnect timer to spawn op - timer_callback = mock_timer.call_args[0][1] - timer_callback() - # Get the reconnect op for this second attempt - assert stage.send_op_down.call_count == 2 - reconnect_op = stage.send_op_down.call_args[0][0] - - # Clean up mocks - mock_timer.reset_mock() - stage.send_op_down.reset_mock() - stage.send_event_up.reset_mock() - stage.report_background_exception.reset_mock() - return reconnect_op - - @pytest.mark.it("Re-runs all of the ops in the `waiting_ops` queue (if any)") - @pytest.mark.parametrize( - "queued_ops", - [ - pytest.param( - [pipeline_ops_base.DisconnectOperation(callback=None)], id="Single op waiting" - ), - pytest.param( - [ - pipeline_ops_base.ReauthorizeConnectionOperation(callback=None), - pipeline_ops_base.ConnectOperation(callback=None), - ], - id="Multiple ops waiting", - ), - ], - ) - @pytest.mark.parametrize( - "success", - [ - pytest.param(True, id="Operation completes with success"), - pytest.param(False, id="Operation completes with error"), - ], - ) - def test_waiting_rerun( - self, mocker, stage, reconnect_op, queued_ops, success, arbitrary_exception - ): - # Before completion, more ops come down and queue up - for queued_op in queued_ops: - stage.run_op(queued_op) - assert stage.waiting_ops.qsize() == len(queued_ops) - - # Now mock out run_op so we can see if it gets called - stage.run_op = mocker.MagicMock() - - # Before operations complete successfully, an event will be fired, and this event will - # trigger state change. We need to emulate this here so that the waiting ops will - # not end up requeued. 
- if success: - stage.nucleus.connection_state = ConnectionState.CONNECTED - reconnect_op.complete() - else: - reconnect_op.complete(arbitrary_exception) - - # All items were removed from the waiting queue and run on the stage - assert stage.waiting_ops.qsize() == 0 - assert stage.run_op.call_count == len(queued_ops) - for i in range(len(queued_ops)): - assert stage.run_op.call_args_list[i] == mocker.call(queued_ops[i]) - - @pytest.mark.it("Reports the error as a background exception if completed with error") - def test_failure_report_background_exception( - self, mocker, stage, reconnect_op, arbitrary_exception - ): - assert stage.report_background_exception.call_count == 0 - - reconnect_op.complete(error=arbitrary_exception) - - assert stage.report_background_exception.call_count == 1 - assert stage.report_background_exception.call_args == mocker.call(arbitrary_exception) - - @pytest.mark.it("Changes the state to DISCONNECTED if completed with error") - def test_failure_state_change(self, stage, reconnect_op, arbitrary_exception): - assert stage.nucleus.connection_state is ConnectionState.CONNECTING - - reconnect_op.complete(error=arbitrary_exception) - - assert stage.nucleus.connection_state is ConnectionState.DISCONNECTED - - @pytest.mark.it("Does not change the state if completed successfully") - def test_success_state_change(self, stage, reconnect_op): - assert stage.nucleus.connection_state is ConnectionState.CONNECTING - - # emulate state change from the ConnectedEvent firing - stage.nucleus.connection_state = ConnectionState.CONNECTED - # complete the op - reconnect_op.complete() - - assert stage.nucleus.connection_state is ConnectionState.CONNECTED - - @pytest.mark.it( - "Starts a new reconnect timer if the operation completed with a transient error" - ) - @pytest.mark.parametrize( - "error", - [ - pytest.param(pipeline_exceptions.OperationCancelled(), id="OperationCancelled"), - pytest.param(pipeline_exceptions.OperationTimeout(), id="OperationTimeout"), 
- pytest.param(pipeline_exceptions.OperationError(), id="OperationError"), - pytest.param(transport_exceptions.ConnectionFailedError(), id="ConnectionFailedError"), - pytest.param( - transport_exceptions.ConnectionDroppedError(), id="ConnectionDroppedError" - ), - pytest.param(transport_exceptions.TlsExchangeAuthError(), id="TlsExchangeAuthError"), - ], - ) - def test_transient_error_completion(self, mocker, stage, reconnect_op, mock_timer, error): - assert stage.reconnect_timer is None - assert mock_timer.call_count == 0 - - reconnect_op.complete(error=error) - - assert mock_timer.call_count == 1 - assert mock_timer.call_args == mocker.call( - stage.nucleus.pipeline_configuration.connection_retry_interval, mocker.ANY - ) - assert stage.reconnect_timer is mock_timer.return_value - - @pytest.mark.it( - "Does not start a new reconnect timer if the operation completed with a non-transient (i.e. non-recoverable) error" - ) - def test_non_transient_error_completion( - self, stage, reconnect_op, mock_timer, arbitrary_exception - ): - assert stage.reconnect_timer is None - assert mock_timer.call_count == 0 - - reconnect_op.complete(error=arbitrary_exception) - - assert mock_timer.call_count == 0 - assert stage.reconnect_timer is None diff --git a/tests/unit/common/pipeline/test_pipeline_stages_http.py b/tests/unit/common/pipeline/test_pipeline_stages_http.py deleted file mode 100644 index 53a591454..000000000 --- a/tests/unit/common/pipeline/test_pipeline_stages_http.py +++ /dev/null @@ -1,291 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import logging -import pytest -import sys -from azure.iot.device.common.pipeline import ( - pipeline_ops_base, - pipeline_ops_http, - pipeline_stages_http, -) -from tests.unit.common.pipeline.helpers import StageRunOpTestBase -from tests.unit.common.pipeline import pipeline_stage_test - - -this_module = sys.modules[__name__] -logging.basicConfig(level=logging.DEBUG) -pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") - -################### -# COMMON FIXTURES # -################### - - -@pytest.fixture -def mock_transport(mocker): - return mocker.patch( - "azure.iot.device.common.pipeline.pipeline_stages_http.HTTPTransport", autospec=True - ) - - -# Not a fixture, but used in parametrization -def fake_callback(): - pass - - -######################## -# HTTP TRANSPORT STAGE # -######################## - - -class HTTPTransportStageTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_http.HTTPTransportStage - - @pytest.fixture - def init_kwargs(self, mocker): - return {} - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.nucleus.pipeline_configuration.hostname = "some.fake-host.name.com" - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - -class HTTPTransportInstantiationTests(HTTPTransportStageTestConfig): - @pytest.mark.it("Initializes 'transport' attribute as None") - def test_transport(self, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - assert stage.transport is None - - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_http.HTTPTransportStage, - stage_test_config_class=HTTPTransportStageTestConfig, - 
extended_stage_instantiation_test_class=HTTPTransportInstantiationTests, -) - - -@pytest.mark.describe("HTTPTransportStage - .run_op() -- Called with InitializePipelineOperation") -class TestHTTPTransportStageRunOpCalledWithInitializePipelineOperation( - HTTPTransportStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - op = pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) - return op - - @pytest.mark.it( - "Creates an HTTPTransport object and sets it as the 'transport' attribute of the stage (and on the pipeline root)" - ) - @pytest.mark.parametrize( - "gateway_hostname", - [ - pytest.param("fake.gateway.hostname.com", id="Using Gateway Hostname"), - pytest.param(None, id="Not using Gateway Hostname"), - ], - ) - def test_creates_transport(self, mocker, stage, op, mock_transport, gateway_hostname): - # Setup pipeline config - stage.nucleus.pipeline_configuration.gateway_hostname = gateway_hostname - - # NOTE: if more of this type of logic crops up, consider splitting this test up - if stage.nucleus.pipeline_configuration.gateway_hostname: - expected_hostname = stage.nucleus.pipeline_configuration.gateway_hostname - else: - expected_hostname = stage.nucleus.pipeline_configuration.hostname - - assert stage.transport is None - - stage.run_op(op) - - assert mock_transport.call_count == 1 - assert mock_transport.call_args == mocker.call( - hostname=expected_hostname, - server_verification_cert=stage.nucleus.pipeline_configuration.server_verification_cert, - x509_cert=stage.nucleus.pipeline_configuration.x509, - cipher=stage.nucleus.pipeline_configuration.cipher, - proxy_options=stage.nucleus.pipeline_configuration.proxy_options, - ) - assert stage.transport is mock_transport.return_value - - @pytest.mark.it("Completes the operation with success, upon successful execution") - def test_succeeds(self, mocker, stage, op, mock_transport): - assert not op.completed - stage.run_op(op) - assert op.completed - - -# NOTE: 
The HTTPTransport object is not instantiated upon instantiation of the HTTPTransportStage. -# It is only added once the InitializePipelineOperation runs. -# The lifecycle of the HTTPTransportStage is as follows: -# 1. Instantiate the stage -# 2. Configure the stage with an InitializePipelineOperation -# 3. Run any other desired operations. -# -# This is to say, no operation should be running before InitializePipelineOperation. -# Thus, for the following tests, we will assume that the HTTPTransport has already been created, -# and as such, the stage fixture used will have already have one. -class HTTPTransportStageTestConfigComplex(HTTPTransportStageTestConfig): - @pytest.fixture - def stage(self, mocker, request, cls_type, init_kwargs, nucleus, mock_transport): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - - # Set up the Transport on the stage - op = pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) - stage.run_op(op) - - assert stage.transport is mock_transport.return_value - - return stage - - -@pytest.mark.describe( - "HTTPTransportStage - .run_op() -- Called with HTTPRequestAndResponseOperation" -) -class TestHTTPTransportStageRunOpCalledWithHTTPRequestAndResponseOperation( - HTTPTransportStageTestConfigComplex, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_http.HTTPRequestAndResponseOperation( - method="SOME_METHOD", - path="fake/path", - headers={"fake_key": "fake_val"}, - body="fake_body", - query_params="arg1=val1;arg2=val2", - callback=mocker.MagicMock(), - ) - - @pytest.mark.it("Sends an HTTP request via the HTTPTransport") - def test_http_request(self, mocker, stage, op): - stage.run_op(op) - - assert stage.transport.request.call_count == 1 - assert stage.transport.request.call_args == mocker.call( - method=op.method, - path=op.path, - # 
headers are tested in depth in the following two tests - headers=mocker.ANY, - body=op.body, - query_params=op.query_params, - callback=mocker.ANY, - ) - - @pytest.mark.it( - "Adds the SasToken in the request's 'Authorization' header if using SAS-based authentication" - ) - def test_headers_with_sas_auth(self, mocker, stage, op): - # A SasToken is set on the pipeline, but Authorization headers have not yet been set - assert stage.nucleus.pipeline_configuration.sastoken is not None - assert op.headers.get("Authorization") is None - - stage.run_op(op) - - # Need to get the headers sent to the transport, not provided by the op, due to a - # deep copy that occurs - headers = stage.transport.request.call_args[1]["headers"] - assert headers["Authorization"] == str(stage.nucleus.pipeline_configuration.sastoken) - - @pytest.mark.it( - "Does NOT add the 'Authorization' header to the request if NOT using SAS-based authentication" - ) - def test_headers_with_no_sas(self, mocker, stage, op): - # NO SasToken is set on the pipeline, and Authorization headers have not yet been set - stage.nucleus.pipeline_configuration.sastoken = None - assert op.headers.get("Authorization") is None - - stage.run_op(op) - - # Need to get the headers sent to the transport, not provided by the op, due to a - # deep copy that occurs - headers = stage.transport.request.call_args[1]["headers"] - assert headers.get("Authorization") is None - - @pytest.mark.it( - "Completes the operation unsuccessfully if there is a failure requesting via the HTTPTransport, using the error raised by the HTTPTransport" - ) - def test_fails_operation(self, mocker, stage, op, arbitrary_exception): - stage.transport.request.side_effect = arbitrary_exception - stage.run_op(op) - assert op.completed - assert op.error is arbitrary_exception - - @pytest.mark.it( - "Completes the operation successfully if the request invokes the provided callback without an error" - ) - def test_completes_callback(self, mocker, stage, op): - def 
mock_request_callback(method, path, headers, query_params, body, callback): - fake_response = { - "resp": "__fake_response__", - "status_code": "__fake_status_code__", - "reason": "__fake_reason__", - } - return callback(response=fake_response) - - # This is a way for us to mock the transport invoking the callback - stage.transport.request.side_effect = mock_request_callback - stage.run_op(op) - assert op.completed - - @pytest.mark.it( - "Adds a reason, status code, and response body to the op if request invokes the provided callback without an error" - ) - def test_formats_op_on_complete(self, mocker, stage, op): - def mock_request_callback(method, path, headers, query_params, body, callback): - fake_response = { - "resp": "__fake_response__", - "status_code": "__fake_status_code__", - "reason": "__fake_reason__", - } - return callback(response=fake_response) - - # This is a way for us to mock the transport invoking the callback - stage.transport.request.side_effect = mock_request_callback - stage.run_op(op) - assert op.reason == "__fake_reason__" - assert op.response_body == "__fake_response__" - assert op.status_code == "__fake_status_code__" - - @pytest.mark.it( - "Completes the operation with an error if the request invokes the provided callback with the same error" - ) - def test_completes_callback_with_error(self, mocker, stage, op, arbitrary_exception): - def mock_on_response_complete(method, path, headers, query_params, body, callback): - return callback(error=arbitrary_exception) - - stage.transport.request.side_effect = mock_on_response_complete - stage.run_op(op) - assert op.completed - assert op.error is arbitrary_exception - - -# NOTE: This is not something that should ever happen in correct program flow -# There should be no operations that make it to the HTTPTransportStage that are not handled by it -@pytest.mark.describe("HTTPTransportStage - .run_op() -- called with arbitrary other operation") -class 
TestHTTPTransportStageRunOpCalledWithArbitraryOperation( - HTTPTransportStageTestConfigComplex, StageRunOpTestBase -): - @pytest.fixture - def op(self, arbitrary_op): - return arbitrary_op - - @pytest.mark.it("Sends the operation down") - def test_sends_op_down(self, mocker, stage, op): - stage.run_op(op) - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) diff --git a/tests/unit/common/pipeline/test_pipeline_stages_mqtt.py b/tests/unit/common/pipeline/test_pipeline_stages_mqtt.py deleted file mode 100644 index ae65958fb..000000000 --- a/tests/unit/common/pipeline/test_pipeline_stages_mqtt.py +++ /dev/null @@ -1,1400 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import logging -import pytest -import sys -import threading -from azure.iot.device.common import transport_exceptions, handle_exceptions -from azure.iot.device.common.pipeline import ( - pipeline_ops_base, - pipeline_ops_mqtt, - pipeline_events_base, - pipeline_events_mqtt, - pipeline_stages_mqtt, - pipeline_exceptions, -) -from tests.unit.common.pipeline.helpers import StageRunOpTestBase -from tests.unit.common.pipeline import pipeline_stage_test - -this_module = sys.modules[__name__] -logging.basicConfig(level=logging.DEBUG) -pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") - -logging.getLogger("azure.iot.device.common").setLevel(level=logging.DEBUG) - -################### -# COMMON FIXTURES # -################### - - -@pytest.fixture -def mock_transport(mocker): - return mocker.patch( - "azure.iot.device.common.pipeline.pipeline_stages_mqtt.MQTTTransport", autospec=True - ) - - -@pytest.fixture -def mock_timer(mocker): - return mocker.patch.object(threading, "Timer") - - 
-# Not a fixture, but used in parametrization -def fake_callback(op, error): - pass - - -######################## -# MQTT TRANSPORT STAGE # -######################## - - -class MQTTTransportStageTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_mqtt.MQTTTransportStage - - @pytest.fixture - def init_kwargs(self, mocker): - return {} - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.nucleus.pipeline_configuration.hostname = "some.fake-host.name.com" - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - -class MQTTTransportInstantiationTests(MQTTTransportStageTestConfig): - @pytest.mark.it("Initializes 'transport' attribute as None") - def test_transport(self, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - assert stage.transport is None - - @pytest.mark.it("Initializes with no pending connection operation") - def test_pending_op(self, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - assert stage._pending_connection_op is None - - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_mqtt.MQTTTransportStage, - stage_test_config_class=MQTTTransportStageTestConfig, - extended_stage_instantiation_test_class=MQTTTransportInstantiationTests, -) - - -@pytest.mark.describe("MQTTTransportStage - .run_op() -- Called with InitializePipelineOperation") -class TestMQTTTransportStageRunOpCalledWithInitializePipelineOperation( - MQTTTransportStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - op = pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) - # These values are patched onto the op in a previous stage - op.client_id = "fake_client_id" - op.username = "fake_username" - return op - - @pytest.mark.it( - 
"Creates an MQTTTransport object and sets it as the 'transport' attribute of the stage" - ) - @pytest.mark.parametrize( - "websockets", - [ - pytest.param(True, id="Pipeline configured for websockets"), - pytest.param(False, id="Pipeline NOT configured for websockets"), - ], - ) - @pytest.mark.parametrize( - "cipher", - [ - pytest.param("DHE-RSA-AES128-SHA", id="Pipeline configured for custom cipher"), - pytest.param( - "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", - id="Pipeline configured for multiple custom ciphers", - ), - pytest.param("", id="Pipeline NOT configured for custom cipher(s)"), - ], - ) - @pytest.mark.parametrize( - "proxy_options", - [ - pytest.param("FAKE-PROXY", id="Proxy present"), - pytest.param(None, id="Proxy None"), - pytest.param("", id="Proxy Absent"), - ], - ) - @pytest.mark.parametrize( - "gateway_hostname", - [ - pytest.param("fake.gateway.hostname.com", id="Using Gateway Hostname"), - pytest.param(None, id="Not using Gateway Hostname"), - ], - ) - @pytest.mark.parametrize( - "keep_alive", - [ - pytest.param(900, id="Pipeline configured for custom keep alive"), - pytest.param(None, id="Pipeline NOT configured for custom keep alive"), - ], - ) - def test_creates_transport( - self, - mocker, - stage, - op, - mock_transport, - websockets, - cipher, - proxy_options, - gateway_hostname, - keep_alive, - ): - # Configure websockets & cipher & keep alive - stage.nucleus.pipeline_configuration.websockets = websockets - stage.nucleus.pipeline_configuration.cipher = cipher - stage.nucleus.pipeline_configuration.proxy_options = proxy_options - stage.nucleus.pipeline_configuration.gateway_hostname = gateway_hostname - stage.nucleus.pipeline_configuration.keep_alive = keep_alive - - # NOTE: if more of this type of logic crops up, consider splitting this test up - if stage.nucleus.pipeline_configuration.gateway_hostname: - expected_hostname = stage.nucleus.pipeline_configuration.gateway_hostname - else: - expected_hostname = 
stage.nucleus.pipeline_configuration.hostname - - assert stage.transport is None - - stage.run_op(op) - - assert mock_transport.call_count == 1 - assert mock_transport.call_args == mocker.call( - client_id=op.client_id, - hostname=expected_hostname, - username=op.username, - server_verification_cert=stage.nucleus.pipeline_configuration.server_verification_cert, - x509_cert=stage.nucleus.pipeline_configuration.x509, - websockets=websockets, - cipher=cipher, - proxy_options=proxy_options, - keep_alive=keep_alive, - ) - assert stage.transport is mock_transport.return_value - - @pytest.mark.it("Sets event handlers on the newly created MQTTTransport") - def test_sets_transport_handlers(self, mocker, stage, op, mock_transport): - stage.run_op(op) - - assert stage.transport.on_mqtt_disconnected_handler == stage._on_mqtt_disconnected - assert stage.transport.on_mqtt_connected_handler == stage._on_mqtt_connected - assert ( - stage.transport.on_mqtt_connection_failure_handler == stage._on_mqtt_connection_failure - ) - assert stage.transport.on_mqtt_message_received_handler == stage._on_mqtt_message_received - - @pytest.mark.it("Sets the stage's pending connection operation to None") - def test_pending_conn_op(self, mocker, stage, op, mock_transport): - # NOTE: The pending connection operation ALREADY should be None, but we set it to None - # again for safety here just in case. So this test is for an edge case. - stage._pending_connection_op = mocker.MagicMock() - stage.run_op(op) - assert stage._pending_connection_op is None - - @pytest.mark.it("Completes the operation with success, upon successful execution") - def test_succeeds(self, mocker, stage, op, mock_transport): - assert not op.completed - stage.run_op(op) - assert op.completed - - -# NOTE: The MQTTTransport object is not instantiated upon instantiation of the MQTTTransportStage. -# It is only added once the InitializePipelineOperation runs. -# The lifecycle of the MQTTTransportStage is as follows: -# 1. 
Instantiate the stage -# 2. Configure the stage with an InitializePipelineOperation -# 3. Run any other desired operations. -# -# This is to say, no operation should be running before InitializePipelineOperation. -# Thus, for the following tests, we will assume that the MQTTTransport has already been created, -# and as such, the stage fixture used will have already have one. -class MQTTTransportStageTestConfigComplex(MQTTTransportStageTestConfig): - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus, mock_transport): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - - # Set up the Transport on the stage - op = pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) - op.client_id = "fake_client_id" - op.username = "fake_username" - stage.run_op(op) - - assert stage.transport is mock_transport.return_value - - return stage - - -@pytest.mark.describe("MQTTTransportStage - .run_op() -- Called with ShutdownPipelineOperation") -class TestMQTTTransportStageRunOpCalledWithShutdownPipelineOperation( - MQTTTransportStageTestConfigComplex, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.ShutdownPipelineOperation(callback=mocker.MagicMock()) - - @pytest.mark.it("Performs a shutdown of the MQTTTransport") - def test_transport_shutdown(self, mocker, stage, op): - stage.run_op(op) - assert stage.transport.shutdown.call_count == 1 - assert stage.transport.shutdown.call_args == mocker.call() - - @pytest.mark.it( - "Completes the operation successfully if there is no error in executing the MQTTTransport shutdown" - ) - def test_no_error(self, stage, op): - stage.run_op(op) - assert op.completed - assert op.error is None - - @pytest.mark.it( - "Completes the operation unsuccessfully (with error) if there was an error in executing the MQTTTransport 
shutdown" - ) - def test_error_occurs(self, mocker, stage, op, arbitrary_exception): - stage.transport.shutdown.side_effect = arbitrary_exception - stage.run_op(op) - assert op.completed - assert op.error is arbitrary_exception - - -@pytest.mark.describe("MQTTTransportStage - .run_op() -- Called with ConnectOperation") -class TestMQTTTransportStageRunOpCalledWithConnectOperation( - MQTTTransportStageTestConfigComplex, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) - - @pytest.mark.it("Sets the operation as the stage's pending connection operation") - def test_sets_pending_operation(self, stage, op): - stage.run_op(op) - assert stage._pending_connection_op is op - - @pytest.mark.it("Cancels any already pending connection operation") - @pytest.mark.parametrize( - "pending_connection_op", - [ - pytest.param( - pipeline_ops_base.ConnectOperation(callback=fake_callback), - id="Pending ConnectOperation", - ), - pytest.param( - pipeline_ops_base.DisconnectOperation(callback=fake_callback), - id="Pending DisconnectOperation", - ), - ], - ) - def test_pending_operation_cancelled(self, mocker, stage, op, pending_connection_op): - # Set up a pending op - stage._pending_connection_op = pending_connection_op - assert not pending_connection_op.completed - - # Run the connect op - stage.run_op(op) - - # Operation has been completed, with an OperationCancelled exception set indicating early cancellation - assert pending_connection_op.completed - assert type(pending_connection_op.error) is pipeline_exceptions.OperationCancelled - - # New operation is now the pending operation - assert stage._pending_connection_op is op - - @pytest.mark.it("Starts the connection watchdog") - def test_starts_watchdog(self, mocker, stage, op, mock_timer): - stage.run_op(op) - - assert mock_timer.call_count == 1 - assert mock_timer.call_args == mocker.call(60, mocker.ANY) - assert mock_timer.return_value.daemon 
is True - assert mock_timer.return_value.start.call_count == 1 - - @pytest.mark.it( - "Performs an MQTT connect via the MQTTTransport, using the PipelineNucleus' SasToken as a password, if using SAS-based authentication" - ) - def test_mqtt_connect_sastoken(self, mocker, stage, op): - assert stage.nucleus.pipeline_configuration.sastoken is not None - stage.run_op(op) - assert stage.transport.connect.call_count == 1 - assert stage.transport.connect.call_args == mocker.call( - password=str(stage.nucleus.pipeline_configuration.sastoken) - ) - - @pytest.mark.it( - "Performs an MQTT connect via the MQTTTransport, with no password, if NOT using SAS-based authentication" - ) - def test_mqtt_connect_no_sastoken(self, mocker, stage, op): - # no token - stage.nucleus.pipeline_configuration.sastoken = None - stage.run_op(op) - assert stage.transport.connect.call_count == 1 - assert stage.transport.connect.call_args == mocker.call(password=None) - - @pytest.mark.it( - "Completes the operation unsuccessfully if there is a failure connecting via the MQTTTransport, using the error raised by the MQTTTransport" - ) - def test_fails_operation(self, mocker, stage, op, arbitrary_exception): - stage.transport.connect.side_effect = arbitrary_exception - stage.run_op(op) - assert op.completed - assert op.error is arbitrary_exception - - @pytest.mark.it( - "Resets the stage's pending connection operation to None, if there is a failure connecting via the MQTTTransport" - ) - def test_clears_pending_op_on_failure(self, mocker, stage, op, arbitrary_exception): - stage.transport.connect.side_effect = arbitrary_exception - stage.run_op(op) - assert stage._pending_connection_op is None - - @pytest.mark.it( - "Leaves the watchdog running while waiting for the connect operation to complete" - ) - def test_leaves_watchdog_running(self, mocker, stage, op, arbitrary_exception, mock_timer): - stage.run_op(op) - assert mock_timer.return_value.cancel.call_count == 0 - assert op.watchdog_timer is 
mock_timer.return_value - - @pytest.mark.it( - "Cancels the connection watchdog if the MQTTTransport connect operation raises an exception" - ) - def test_cancels_watchdog(self, mocker, stage, op, arbitrary_exception, mock_timer): - stage.transport.connect.side_effect = arbitrary_exception - stage.run_op(op) - assert mock_timer.return_value.cancel.call_count == 1 - assert op.watchdog_timer is None - - -@pytest.mark.describe( - "MQTTTransportStage - .run_op() -- Called with ReauthorizeConnectionOperation" -) -class TestMQTTTransportStageRunOpCalledWithReauthorizeConnectionOperation( - MQTTTransportStageTestConfigComplex, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.ReauthorizeConnectionOperation(callback=mocker.MagicMock()) - - @pytest.mark.it( - "Spawns a new DisconnectOperation (configured as a soft disconnect) and runs it on the stage" - ) - def test_disconnect(self, mocker, stage, op): - original_run_op = stage.run_op - mock_run_op = mocker.MagicMock() - stage.run_op = mock_run_op - - original_run_op(op) - - assert mock_run_op.call_count == 1 - disconnect_op = mock_run_op.call_args[0][0] - assert isinstance(disconnect_op, pipeline_ops_base.DisconnectOperation) - assert disconnect_op.hard is False - - assert not op.completed - - @pytest.mark.it( - "Spawns a new ConnectOperation and runs it on the stage upon completion of the DisconnectOperation" - ) - @pytest.mark.parametrize( - "successful_disconnect", - [ - pytest.param(True, id="Disconnect Completed Successfully"), - pytest.param(False, id="Disconnect Completed with Error"), - ], - ) - def test_connect(self, mocker, stage, op, successful_disconnect, arbitrary_exception): - original_run_op = stage.run_op - mock_run_op = mocker.MagicMock() - stage.run_op = mock_run_op - - original_run_op(op) - - assert mock_run_op.call_count == 1 - disconnect_op = mock_run_op.call_args[0][0] - assert isinstance(disconnect_op, pipeline_ops_base.DisconnectOperation) - - if 
successful_disconnect: - error = None - else: - error = arbitrary_exception - - disconnect_op.complete(error=error) - - assert not op.completed - - assert mock_run_op.call_count == 2 - connect_op = mock_run_op.call_args[0][0] - assert isinstance(connect_op, pipeline_ops_base.ConnectOperation) - - assert not op.completed - - @pytest.mark.it( - "Completes the original ReauthorizeConnectionOperation upon completion of the ConnectOperation" - ) - @pytest.mark.parametrize( - "successful_connect", - [ - pytest.param(True, id="Connect Completed Successfully"), - pytest.param(False, id="Connect Completed with Error"), - ], - ) - def test_completion(self, mocker, stage, op, successful_connect, arbitrary_exception): - original_run_op = stage.run_op - mock_run_op = mocker.MagicMock() - stage.run_op = mock_run_op - - original_run_op(op) - - assert mock_run_op.call_count == 1 - disconnect_op = mock_run_op.call_args[0][0] - assert isinstance(disconnect_op, pipeline_ops_base.DisconnectOperation) - - disconnect_op.complete() - - assert not op.completed - - assert mock_run_op.call_count == 2 - connect_op = mock_run_op.call_args[0][0] - assert isinstance(connect_op, pipeline_ops_base.ConnectOperation) - - assert not op.completed - - if successful_connect: - error = None - else: - error = arbitrary_exception - - connect_op.complete(error=error) - - assert op.completed - assert op.error is error - - -@pytest.mark.describe("MQTTTransportStage - .run_op() -- Called with DisconnectOperation") -class TestMQTTTransportStageRunOpCalledWithDisconnectOperation( - MQTTTransportStageTestConfigComplex, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) - - @pytest.mark.it("Sets the operation as the stage's pending connection operation") - def test_sets_pending_operation(self, stage, op): - stage.run_op(op) - assert stage._pending_connection_op is op - - @pytest.mark.it("Cancels any already pending 
connection operation") - @pytest.mark.parametrize( - "pending_connection_op", - [ - pytest.param( - pipeline_ops_base.ConnectOperation(callback=fake_callback), - id="Pending ConnectOperation", - ), - pytest.param( - pipeline_ops_base.DisconnectOperation(callback=fake_callback), - id="Pending DisconnectOperation", - ), - ], - ) - def test_pending_operation_cancelled(self, mocker, stage, op, pending_connection_op): - # Set up a pending op - stage._pending_connection_op = pending_connection_op - assert not pending_connection_op.completed - - # Run the connect op - stage.run_op(op) - - # Operation has been completed, with an OperationCancelled exception set indicating early cancellation - assert pending_connection_op.completed - assert type(pending_connection_op.error) is pipeline_exceptions.OperationCancelled - - # New operation is now the pending operation - assert stage._pending_connection_op is op - - @pytest.mark.it( - "Performs an MQTT disconnect via the MQTTTransport, using the 'clear_inflight' option only if the operation is configured for a hard disconnect" - ) - def test_mqtt_connect(self, mocker, stage, op): - # Hard disconnect - assert op.hard is True - stage.run_op(op) - assert stage.transport.disconnect.call_count == 1 - assert stage.transport.disconnect.call_args == mocker.call(clear_inflight=True) - - stage.transport.disconnect.reset_mock() - - # Soft disconnect - op.hard = False - stage.run_op(op) - assert stage.transport.disconnect.call_count == 1 - assert stage.transport.disconnect.call_args == mocker.call(clear_inflight=False) - - @pytest.mark.it( - "Completes the operation unsuccessfully if there is a failure disconnecting via the MQTTTransport, using the error raised by the MQTTTransport" - ) - def test_fails_operation(self, mocker, stage, op, arbitrary_exception): - stage.transport.disconnect.side_effect = arbitrary_exception - stage.run_op(op) - assert op.completed - assert op.error is arbitrary_exception - - @pytest.mark.it( - "Resets the 
stage's pending connection operation to None, if there is a failure disconnecting via the MQTTTransport" - ) - def test_clears_pending_op_on_failure(self, mocker, stage, op, arbitrary_exception): - stage.transport.disconnect.side_effect = arbitrary_exception - stage.run_op(op) - assert stage._pending_connection_op is None - - -@pytest.mark.describe("MQTTTransportStage - .run_op() -- called with MQTTPublishOperation") -class TestMQTTTransportStageRunOpCalledWithMQTTPublishOperation( - MQTTTransportStageTestConfigComplex, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_mqtt.MQTTPublishOperation( - topic="fake_topic", payload="fake_payload", callback=mocker.MagicMock() - ) - - @pytest.mark.it("Performs an MQTT publish via the MQTTTransport") - def test_mqtt_publish(self, mocker, stage, op): - stage.run_op(op) - assert stage.transport.publish.call_count == 1 - assert stage.transport.publish.call_args == mocker.call( - topic=op.topic, payload=op.payload, callback=mocker.ANY - ) - - @pytest.mark.it( - "Successfully completes the operation, upon successful completion of the MQTT publish by the MQTTTransport" - ) - def test_complete(self, mocker, stage, op): - # Begin publish - stage.run_op(op) - - assert not op.completed - - # Trigger publish completion - stage.transport.publish.call_args[1]["callback"]() - - assert op.completed - assert op.error is None - - @pytest.mark.it( - "Completes the operation with an OperationCancelled error upon cancellation of the MQTT unsubscribe by the MQTTTransport" - ) - def test_complete_with_cancel(self, mocker, stage, op): - # Begin publish - stage.run_op(op) - - assert not op.completed - - # Trigger publish cancellation - stage.transport.publish.call_args[1]["callback"](cancelled=True) - - assert op.completed - assert isinstance(op.error, pipeline_exceptions.OperationCancelled) - - @pytest.mark.it( - "Completes the operation using the exception that was raised, if an exception was raised from the 
MQTTTransport" - ) - def test_publish_error(self, stage, op, arbitrary_exception): - stage.transport.publish.side_effect = arbitrary_exception - - stage.run_op(op) - - assert op.completed - assert op.error is arbitrary_exception - - -@pytest.mark.describe("MQTTTransportStage - .run_op() -- called with MQTTSubscribeOperation") -class TestMQTTTransportStageRunOpCalledWithMQTTSubscribeOperation( - MQTTTransportStageTestConfigComplex, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_mqtt.MQTTSubscribeOperation( - topic="fake_topic", callback=mocker.MagicMock() - ) - - @pytest.mark.it("Performs an MQTT subscribe via the MQTTTransport") - def test_mqtt_publish(self, mocker, stage, op): - stage.run_op(op) - assert stage.transport.subscribe.call_count == 1 - assert stage.transport.subscribe.call_args == mocker.call( - topic=op.topic, callback=mocker.ANY - ) - - @pytest.mark.it( - "Successfully completes the operation, upon successful completion of the MQTT subscribe by the MQTTTransport" - ) - def test_complete(self, mocker, stage, op): - # Begin subscribe - stage.run_op(op) - - assert not op.completed - - # Trigger subscribe completion - stage.transport.subscribe.call_args[1]["callback"]() - - assert op.completed - assert op.error is None - - @pytest.mark.it( - "Completes the operation with an OperationCancelled error upon cancellation of the MQTT unsubscribe by the MQTTTransport" - ) - def test_complete_with_cancel(self, mocker, stage, op): - # Begin unsubscribe - stage.run_op(op) - - assert not op.completed - - # Trigger subscribe cancellation - stage.transport.subscribe.call_args[1]["callback"](cancelled=True) - - assert op.completed - assert isinstance(op.error, pipeline_exceptions.OperationCancelled) - - @pytest.mark.it( - "Completes the operation using the exception that was raised, if an exception was raised from the MQTTTransport" - ) - def test_subscribe_error(self, stage, op, arbitrary_exception): - 
stage.transport.subscribe.side_effect = arbitrary_exception - - stage.run_op(op) - - assert op.completed - assert op.error is arbitrary_exception - - -@pytest.mark.describe("MQTTTransportStage - .run_op() -- called with MQTTUnsubscribeOperation") -class TestMQTTTransportStageRunOpCalledWithMQTTUnsubscribeOperation( - MQTTTransportStageTestConfigComplex, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_mqtt.MQTTUnsubscribeOperation( - topic="fake_topic", callback=mocker.MagicMock() - ) - - @pytest.mark.it("Performs an MQTT unsubscribe via the MQTTTransport") - def test_mqtt_publish(self, mocker, stage, op): - stage.run_op(op) - assert stage.transport.unsubscribe.call_count == 1 - assert stage.transport.unsubscribe.call_args == mocker.call( - topic=op.topic, callback=mocker.ANY - ) - - @pytest.mark.it( - "Successfully completes the operation upon successful completion of the MQTT unsubscribe by the MQTTTransport" - ) - def test_complete(self, mocker, stage, op): - # Begin unsubscribe - stage.run_op(op) - - assert not op.completed - - # Trigger unsubscribe completion - stage.transport.unsubscribe.call_args[1]["callback"]() - - assert op.completed - assert op.error is None - - @pytest.mark.it( - "Completes the operation with an OperationCancelled error upon cancellation of the MQTT unsubscribe by the MQTTTransport" - ) - def test_complete_with_cancel(self, mocker, stage, op): - # Begin unsubscribe - stage.run_op(op) - - assert not op.completed - - # Trigger unsubscribe cancellation - stage.transport.unsubscribe.call_args[1]["callback"](cancelled=True) - - assert op.completed - assert isinstance(op.error, pipeline_exceptions.OperationCancelled) - - @pytest.mark.it( - "Completes the operation using the exception that was raised, if an exception was raised from the MQTTTransport" - ) - def test_publish_error(self, stage, op, arbitrary_exception): - stage.transport.unsubscribe.side_effect = arbitrary_exception - - stage.run_op(op) - - 
assert op.completed - assert op.error is arbitrary_exception - - -# NOTE: This is not something that should ever happen in correct program flow -# There should be no operations that make it to the MQTTTransportStage that are not handled by it -@pytest.mark.describe("MQTTTransportStage - .run_op() -- called with arbitrary other operation") -class TestMQTTTransportStageRunOpCalledWithArbitraryOperation( - MQTTTransportStageTestConfigComplex, StageRunOpTestBase -): - @pytest.fixture - def op(self, arbitrary_op): - return arbitrary_op - - @pytest.mark.it("Sends the operation down") - def test_sends_op_down(self, mocker, stage, op): - stage.run_op(op) - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - -@pytest.mark.describe("MQTTTransportStage - OCCURRENCE: MQTT message received") -class TestMQTTTransportStageProtocolClientEvents(MQTTTransportStageTestConfigComplex): - @pytest.mark.it("Sends an IncomingMQTTMessageEvent event up the pipeline") - def test_incoming_message_handler(self, stage, mocker): - # Trigger MQTT message received - stage.transport.on_mqtt_message_received_handler(topic="fake_topic", payload="fake_payload") - - assert stage.send_event_up.call_count == 1 - event = stage.send_event_up.call_args[0][0] - assert isinstance(event, pipeline_events_mqtt.IncomingMQTTMessageEvent) - - @pytest.mark.it("Passes topic and payload as part of the IncomingMQTTMessageEvent") - def test_verify_incoming_message_attributes(self, stage, mocker): - fake_topic = "fake_topic" - fake_payload = "fake_payload" - - # Trigger MQTT message received - stage.transport.on_mqtt_message_received_handler(topic=fake_topic, payload=fake_payload) - - event = stage.send_event_up.call_args[0][0] - assert event.payload == fake_payload - assert event.topic == fake_topic - - -@pytest.mark.describe("MQTTTransportStage - OCCURRENCE: MQTT connected") -class TestMQTTTransportStageOnConnected(MQTTTransportStageTestConfigComplex): - 
@pytest.mark.it("Sends a ConnectedEvent up the pipeline") - @pytest.mark.parametrize( - "pending_connection_op", - [ - pytest.param(None, id="No pending operation"), - pytest.param( - pipeline_ops_base.ConnectOperation(callback=fake_callback), - id="Pending ConnectOperation", - ), - pytest.param( - pipeline_ops_base.ReauthorizeConnectionOperation(callback=fake_callback), - id="Pending ReauthorizeConnectionOperation", - ), - pytest.param( - pipeline_ops_base.DisconnectOperation(callback=fake_callback), - id="Pending DisconnectOperation", - ), - ], - ) - def test_sends_event_up(self, stage, pending_connection_op): - stage._pending_connection_op = pending_connection_op - # Trigger connect completion - stage.transport.on_mqtt_connected_handler() - - assert stage.send_event_up.call_count == 1 - connect_event = stage.send_event_up.call_args[0][0] - assert isinstance(connect_event, pipeline_events_base.ConnectedEvent) - - @pytest.mark.it("Completes a pending ConnectOperation successfully") - def test_completes_pending_connect_op(self, mocker, stage): - # Set a pending connect operation - op = pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) - stage.run_op(op) - assert not op.completed - assert stage._pending_connection_op is op - - # Trigger connect completion - stage.transport.on_mqtt_connected_handler() - - # Connect operation completed successfully - assert op.completed - assert op.error is None - assert stage._pending_connection_op is None - - @pytest.mark.it( - "Does not complete a pending DisconnectOperation when the transport connected event fires" - ) - def test_does_not_complete_pending_disconnect_op(self, mocker, stage): - # Set a pending disconnect operation - op = pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) - stage.run_op(op) - assert not op.completed - assert stage._pending_connection_op is op - - # Trigger connect completion - stage.transport.on_mqtt_connected_handler() - - # Disconnect operation was NOT completed - 
assert not op.completed - assert stage._pending_connection_op is op - - @pytest.mark.it( - "Cancels the connection watchdog if the pending operation is a ConnectOperation" - ) - def test_cancels_watchdog_on_pending_connect(self, mocker, stage, mock_timer): - # Set a pending connect operation - op = pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) - stage.run_op(op) - - # assert watchdog is running - assert op.watchdog_timer is mock_timer.return_value - assert op.watchdog_timer.start.call_count == 1 - - # Trigger connect completion - stage.transport.on_mqtt_connected_handler() - - # assert watchdog was cancelled - assert op.watchdog_timer is None - assert mock_timer.return_value.cancel.call_count == 1 - - @pytest.mark.it( - "Does not cancels the connection watchdog if the pending operation is DisconnectOperation because there is no connection watchdog" - ) - def test_does_not_cancel_watchdog_on_pending_disconnect(self, mocker, stage, mock_timer): - # Set a pending disconnect operation - op = pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) - stage.run_op(op) - - # assert no timers are running - assert mock_timer.return_value.start.call_count == 0 - - # Trigger connect completion - stage.transport.on_mqtt_connected_handler() - - # assert no timers are still running - assert mock_timer.return_value.start.call_count == 0 - assert mock_timer.return_value.cancel.call_count == 0 - - -@pytest.mark.describe("MQTTTransportStage - OCCURRENCE: MQTT connection failure") -class TestMQTTTransportStageOnConnectionFailure(MQTTTransportStageTestConfigComplex): - @pytest.mark.it("Does not send any events up the pipeline") - @pytest.mark.parametrize( - "pending_connection_op", - [ - pytest.param(None, id="No pending operation"), - pytest.param( - pipeline_ops_base.ConnectOperation(callback=fake_callback), - id="Pending ConnectOperation", - ), - pytest.param( - pipeline_ops_base.ReauthorizeConnectionOperation(callback=fake_callback), - id="Pending 
ReauthorizeConnectionOperation", - ), - pytest.param( - pipeline_ops_base.DisconnectOperation(callback=fake_callback), - id="Pending DisconnectOperation", - ), - ], - ) - def test_does_not_send_event(self, mocker, stage, pending_connection_op, arbitrary_exception): - stage._pending_connection_op = pending_connection_op - - # Trigger connection failure with an arbitrary cause - stage.transport.on_mqtt_connection_failure_handler(arbitrary_exception) - - assert stage.send_event_up.call_count == 0 - - @pytest.mark.it( - "Completes a pending ConnectOperation unsuccessfully with the cause of connection failure as the error" - ) - def test_fails_pending_connect_op(self, mocker, stage, arbitrary_exception): - # Create a pending ConnectOperation - op = pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) - stage.run_op(op) - assert not op.completed - assert stage._pending_connection_op is op - - # Trigger connection failure with an arbitrary cause - stage.transport.on_mqtt_connection_failure_handler(arbitrary_exception) - - assert op.completed - assert op.error is arbitrary_exception - assert stage._pending_connection_op is None - - @pytest.mark.it("Ignores a pending DisconnectOperation, and does not complete it") - def test_ignores_pending_disconnect_op(self, mocker, stage, arbitrary_exception): - # Create a pending DisconnectOperation - op = pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) - stage.run_op(op) - assert not op.completed - assert stage._pending_connection_op is op - - # Trigger connection failure with an arbitrary cause - stage.transport.on_mqtt_connection_failure_handler(arbitrary_exception) - - # Assert nothing changed about the operation - assert not op.completed - assert stage._pending_connection_op is op - - @pytest.mark.it( - "Triggers the swallowed exception handler (with error cause) when the connection failure is unexpected" - ) - @pytest.mark.parametrize( - "pending_connection_op", - [ - pytest.param(None, id="No pending 
operation"), - pytest.param( - pipeline_ops_base.DisconnectOperation(callback=fake_callback), - id="Pending DisconnectOperation", - ), - ], - ) - def test_unexpected_connection_failure( - self, mocker, stage, arbitrary_exception, pending_connection_op - ): - # A connection failure is unexpected if there is not a pending Connect operation - # i.e. "Why did we get a connection failure? We weren't even trying to connect!" - mock_handler = mocker.patch.object(handle_exceptions, "swallow_unraised_exception") - stage._pending_connection_operation = pending_connection_op - - # Trigger connection failure with arbitrary cause - stage.transport.on_mqtt_connection_failure_handler(arbitrary_exception) - - # swallow exception handler has been called - assert mock_handler.call_count == 1 - assert mock_handler.call_args == mocker.call( - arbitrary_exception, log_msg=mocker.ANY, log_lvl="info" - ) - - @pytest.mark.it( - "Cancels the connection watchdog if the pending operation is a ConnectOperation" - ) - def test_cancels_watchdog_on_pending_connect( - self, mocker, stage, mock_timer, arbitrary_exception - ): - # Set a pending connect operation - op = pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) - stage.run_op(op) - - # assert watchdog is running - assert op.watchdog_timer is mock_timer.return_value - assert op.watchdog_timer.start.call_count == 1 - - # Trigger connection failure with arbitrary cause - stage.transport.on_mqtt_connection_failure_handler(arbitrary_exception) - - # assert watchdog was cancelled - assert op.watchdog_timer is None - assert mock_timer.return_value.cancel.call_count == 1 - - @pytest.mark.it( - "Does not cancels the connection watchdog if the pending operation is DisconnectOperation" - ) - def test_does_not_cancel_watchdog_on_pending_disconnect( - self, mocker, stage, mock_timer, arbitrary_exception - ): - # Set a pending disconnect operation - op = pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) - stage.run_op(op) - 
- # assert no timers are running - assert mock_timer.return_value.start.call_count == 0 - - # Trigger connection failure with arbitrary cause - stage.transport.on_mqtt_connection_failure_handler(arbitrary_exception) - - # assert no timers are still running - assert mock_timer.return_value.start.call_count == 0 - assert mock_timer.return_value.cancel.call_count == 0 - - -@pytest.mark.describe("MQTTTransportStage - OCCURRENCE: MQTT disconnected (Expected)") -class TestMQTTTransportStageOnDisconnectedExpected(MQTTTransportStageTestConfigComplex): - @pytest.fixture(params=[False, True], ids=["No error cause", "With error cause"]) - def cause(self, request, arbitrary_exception): - if request.param: - return arbitrary_exception - else: - return None - - @pytest.fixture - def pending_connection_op(self): - return pipeline_ops_base.DisconnectOperation(callback=fake_callback) - - @pytest.mark.it("Sends a DisconnectedEvent up the pipeline") - def test_disconnect_event_sent(self, stage, cause, pending_connection_op): - stage._pending_connection_op = pending_connection_op - assert stage.send_event_up.call_count == 0 - - # Trigger disconnect - stage.transport.on_mqtt_disconnected_handler(cause) - - assert stage.send_event_up.call_count == 1 - event = stage.send_event_up.call_args[0][0] - assert isinstance(event, pipeline_events_base.DisconnectedEvent) - - @pytest.mark.it("Swallows the exception that caused the disconnect if the cause is specified") - def test_error_swallowed(self, mocker, stage, arbitrary_exception, pending_connection_op): - mock_swallow = mocker.patch.object(handle_exceptions, "swallow_unraised_exception") - stage._pending_connection_op = pending_connection_op - - # Trigger disconnect with arbitrary cause - stage.transport.on_mqtt_disconnected_handler(arbitrary_exception) - - # Exception swallower was called - assert mock_swallow.call_count == 1 - assert mock_swallow.call_args == mocker.call(arbitrary_exception, log_msg=mocker.ANY) - - @pytest.mark.it( - 
"Completes the pending DisconnectOperation successfully and removes its pending status" - ) - def test_disconnect_op_completed(self, mocker, stage, cause, pending_connection_op): - stage._pending_connection_op = pending_connection_op - assert not pending_connection_op.completed - assert pending_connection_op.error is None - - # Trigger disconnect - stage.transport.on_mqtt_disconnected_handler(cause) - - assert stage._pending_connection_op is None - assert pending_connection_op.completed - assert pending_connection_op.error is None - - -@pytest.mark.describe( - "MQTTTransportStage - OCCURRENCE: MQTT disconnected (Unexpected - pending ConnectionOperation)" -) -class TestMQTTTransportStageOnDisconnectedUnexpectedWithPendingConnectOp( - MQTTTransportStageTestConfigComplex -): - @pytest.fixture(params=[False, True], ids=["No error cause", "With error cause"]) - def cause(self, request, arbitrary_exception): - if request.param: - return arbitrary_exception - else: - return None - - @pytest.fixture - def pending_connection_op(self): - return pipeline_ops_base.ConnectOperation(callback=fake_callback) - - @pytest.mark.it("Sends a DisconnectedEvent up the pipeline") - def test_disconnect_event_sent(self, stage, cause, pending_connection_op): - stage._pending_connection_op = pending_connection_op - assert stage.send_event_up.call_count == 0 - - # Trigger disconnect - stage.transport.on_mqtt_disconnected_handler(cause) - - assert stage.send_event_up.call_count == 1 - event = stage.send_event_up.call_args[0][0] - assert isinstance(event, pipeline_events_base.DisconnectedEvent) - - @pytest.mark.it( - "Completes the pending ConnectOperation unsuccessfully with the cause of the disconnection set as the error, and removes its pending status, if the cause is specified" - ) - def test_op_completed_with_cause(self, stage, arbitrary_exception, pending_connection_op): - stage._pending_connection_op = pending_connection_op - assert not pending_connection_op.completed - assert 
pending_connection_op.error is None - - # Trigger disconnect with arbitrary cause - stage.transport.on_mqtt_disconnected_handler(arbitrary_exception) - - assert stage._pending_connection_op is None - assert pending_connection_op.completed - assert pending_connection_op.error is arbitrary_exception - - @pytest.mark.it( - "Completes the pending ConnectOperation unsuccessfully with a ConnectionDroppedError, and removes its pending status, if no cause is provided for the disconnection" - ) - def test_op_completed_no_cause(self, stage, pending_connection_op): - stage._pending_connection_op = pending_connection_op - assert not pending_connection_op.completed - assert pending_connection_op.error is None - - # Trigger disconnect with no cause - stage.transport.on_mqtt_disconnected_handler() - - assert stage._pending_connection_op is None - assert pending_connection_op.completed - assert isinstance(pending_connection_op.error, transport_exceptions.ConnectionDroppedError) - - @pytest.mark.it("Cancels the connection watchdog") - def test_cancels_watchdog(self, mocker, stage, mock_timer, cause, pending_connection_op): - # Set a pending connect operation - stage.run_op(pending_connection_op) - - # assert watchdog is running - assert pending_connection_op.watchdog_timer is mock_timer.return_value - assert pending_connection_op.watchdog_timer.start.call_count == 1 - - # Trigger disconnect - stage.transport.on_mqtt_disconnected_handler(cause) - - # assert watchdog was cancelled - assert pending_connection_op.watchdog_timer is None - assert mock_timer.return_value.cancel.call_count == 1 - - -@pytest.mark.describe( - "MQTTTransportStage - OCCURRENCE: MQTT disconnected (Unexpected - no pending operation)" -) -class TestMQTTTransportStageOnDisconnectedUnexpectedNoPendingConnectionOp( - MQTTTransportStageTestConfigComplex -): - @pytest.fixture(params=[False, True], ids=["No error cause", "With error cause"]) - def cause(self, request, arbitrary_exception): - if request.param: - return 
arbitrary_exception - else: - return None - - @pytest.mark.it( - "Cancels all in-flight operations in the transport, if connection retry has been disabled" - ) - def test_inflight_no_retry(self, mocker, stage, cause): - stage.transport._op_manager = mocker.MagicMock() - mock_cancel = stage.transport._op_manager.cancel_all_operations - stage.nucleus.pipeline_configuration.connection_retry = False - assert stage._pending_connection_op is None - assert mock_cancel.call_count == 0 - - # Trigger disconnect - stage.transport.on_mqtt_disconnected_handler(cause) - - assert mock_cancel.call_count == 1 - assert mock_cancel.call_args == mocker.call() - - @pytest.mark.it( - "Does not cancel any in-flight operations in the transport if connection retry has been enabled" - ) - def test_inflight_unexpected_with_retry(self, mocker, stage, cause): - stage.transport._op_manager = mocker.MagicMock() - mock_cancel = stage.transport._op_manager.cancel_all_operations - stage.nucleus.pipeline_configuration.connection_retry = True - assert stage._pending_connection_op is None - assert mock_cancel.call_count == 0 - - # Trigger disconnect - stage.transport.on_mqtt_disconnected_handler(cause) - - assert mock_cancel.call_count == 0 - - @pytest.mark.it("Raises a ConnectionDroppedError as a background exception") - def test_background_exception_raised(self, stage, cause): - assert stage._pending_connection_op is None - assert stage.report_background_exception.call_count == 0 - - # Trigger disconnect - stage.transport.on_mqtt_disconnected_handler(cause) - - assert stage.report_background_exception.call_count == 1 - background_exception = stage.report_background_exception.call_args[0][0] - assert isinstance(background_exception, transport_exceptions.ConnectionDroppedError) - assert background_exception.__cause__ is cause - - -disconnect_can_raise = [ - "disconnect_raises", - [ - pytest.param(True, id="mqtt_transport.disconnect raises an exception"), - pytest.param(False, 
id="mqtt_transport.disconnect does not raises an exception"), - ], -] - - -@pytest.mark.describe("MQTTTransportStage - OCCURRENCE: Connection watchdog expired") -class TestMQTTTransportStageWatchdogExpired(MQTTTransportStageTestConfigComplex): - @pytest.fixture(params=[pipeline_ops_base.ConnectOperation], ids=["Pending ConnectOperation"]) - def pending_op(self, request, mocker): - return request.param(callback=mocker.MagicMock()) - - @pytest.mark.it( - "Performs an MQTT disconnect via the MQTTTransport if the op that started the watchdog is still pending" - ) - def test_calls_disconnect(self, mocker, stage, pending_op, mock_timer): - stage.run_op(pending_op) - - watchdog_expiration = mock_timer.call_args[0][1] - watchdog_expiration() - - assert stage.transport.disconnect.call_count == 1 - - @pytest.mark.it( - "Does not perform an MQTT disconnect via the MQTTTransport if the op that started the watchdog is no longer pending" - ) - def test_does_not_call_disconnect_if_no_longer_pending( - self, mocker, stage, pending_op, mock_timer - ): - stage.run_op(pending_op) - stage._pending_connection_op = None - - watchdog_expiration = mock_timer.call_args[0][1] - watchdog_expiration() - - assert stage.transport.disconnect.call_count == 0 - - @pytest.mark.parametrize(*disconnect_can_raise) - @pytest.mark.it( - "Completes the op that started the watchdog with an OperationTimeout exception if that op is still pending" - ) - def test_completes_with_operation_cancelled( - self, mocker, stage, pending_op, mock_timer, disconnect_raises, arbitrary_exception - ): - if disconnect_raises: - stage.transport.disconnect = mocker.MagicMock(side_effect=arbitrary_exception) - - callback = pending_op.callback_stack[0] - - stage.run_op(pending_op) - - watchdog_expiration = mock_timer.call_args[0][1] - watchdog_expiration() - - assert callback.call_count == 1 - assert isinstance(callback.call_args[1]["error"], pipeline_exceptions.OperationTimeout) - - 
@pytest.mark.parametrize(*disconnect_can_raise) - @pytest.mark.it( - "Does not complete the op that started the watchdog with an OperationCancelled error if that op is no longer pending" - ) - def test_does_not_complete_op_if_no_longer_pending( - self, mocker, stage, pending_op, mock_timer, disconnect_raises, arbitrary_exception - ): - if disconnect_raises: - stage.transport.disconnect = mocker.MagicMock(side_effect=arbitrary_exception) - - callback = pending_op.callback_stack[0] - - stage.run_op(pending_op) - stage._pending_connection_op = None - - watchdog_expiration = mock_timer.call_args[0][1] - watchdog_expiration() - - assert callback.call_count == 0 - - @pytest.mark.parametrize(*disconnect_can_raise) - @pytest.mark.it( - "Sends a DisconnectedEvent if the op that started the watchdog is still pending and the pipeline is connected" - ) - def test_sends_disconnected_event_if_still_pending_and_connected( - self, - mocker, - stage, - pending_op, - mock_timer, - disconnect_raises, - arbitrary_exception, - pipeline_connected_mock, - ): - if disconnect_raises: - stage.transport.disconnect = mocker.MagicMock(side_effect=arbitrary_exception) - - pipeline_connected_mock.return_value = True - assert stage.nucleus.connected - stage.run_op(pending_op) - - watchdog_expiration = mock_timer.call_args[0][1] - watchdog_expiration() - - assert stage.send_event_up.call_count == 1 - assert isinstance( - stage.send_event_up.call_args[0][0], pipeline_events_base.DisconnectedEvent - ) - - @pytest.mark.parametrize(*disconnect_can_raise) - @pytest.mark.it( - "Does not send a DisconnectedEvent if the op that started the watchdog is still pending and the pipeline is not connected" - ) - def test_does_not_send_disconnected_event_if_still_pending_and_not_connected( - self, - mocker, - stage, - pending_op, - mock_timer, - disconnect_raises, - arbitrary_exception, - pipeline_connected_mock, - ): - if disconnect_raises: - stage.transport.disconnect = 
mocker.MagicMock(side_effect=arbitrary_exception) - - pipeline_connected_mock.return_value = False - assert not stage.nucleus.connected - stage.run_op(pending_op) - - watchdog_expiration = mock_timer.call_args[0][1] - watchdog_expiration() - - assert stage.send_event_up.call_count == 0 - - @pytest.mark.parametrize(*disconnect_can_raise) - @pytest.mark.it( - "Does not send a DisconnectedEvent if the op that started the watchdog is no longer pending and the pipeline is connected" - ) - def test_does_not_send_disconnected_event_if_no_longer_pending_and_connected( - self, - mocker, - stage, - pending_op, - mock_timer, - disconnect_raises, - arbitrary_exception, - pipeline_connected_mock, - ): - if disconnect_raises: - stage.transport.disconnect = mocker.MagicMock(side_effect=arbitrary_exception) - - pipeline_connected_mock.return_value = True - assert stage.nucleus.connected - stage.run_op(pending_op) - stage._pending_connection_op = None - - watchdog_expiration = mock_timer.call_args[0][1] - watchdog_expiration() - - assert stage.send_event_up.call_count == 0 - - @pytest.mark.parametrize(*disconnect_can_raise) - @pytest.mark.it( - "Does not send a DisconnectedEvent if the op that started the watchdog is no longer pending and the pipeline connected flag is False" - ) - def test_does_not_send_disconnected_event_if_no_longer_pending_and_not_connected( - self, - mocker, - stage, - pending_op, - mock_timer, - disconnect_raises, - arbitrary_exception, - pipeline_connected_mock, - ): - if disconnect_raises: - stage.transport.disconnect = mocker.MagicMock(side_effect=arbitrary_exception) - - pipeline_connected_mock.return_value = True - assert stage.nucleus.connected - stage.run_op(pending_op) - stage._pending_connection_op = None - - watchdog_expiration = mock_timer.call_args[0][1] - watchdog_expiration() - - assert stage.send_event_up.call_count == 0 diff --git a/tests/unit/common/test_alarm.py b/tests/unit/common/test_alarm.py deleted file mode 100644 index 
4d7c16a1e..000000000 --- a/tests/unit/common/test_alarm.py +++ /dev/null @@ -1,83 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import pytest -import logging -import time -from azure.iot.device.common.alarm import Alarm - -logging.basicConfig(level=logging.DEBUG) - -# NOTE: A fundamental aspect of the Alarm class that makes it different from Timer (beyond the -# input format) is that sleeping the system will not throw off the timekeeping. However, there -# isn't a great way to test that in unit tests, because to do so would involve a system sleep. -# So note well that these tests would pass the same if using a Timer implementation, because the -# thing that makes Alarms unique isn't tested here. - - -@pytest.mark.describe("Alarm") -class TestAlarm(object): - @pytest.fixture - def desired_function(self, mocker): - return mocker.MagicMock() - - @pytest.mark.it( - "Invokes the given function with the given args and kwargs at the given alarm time once started" - ) - @pytest.mark.parametrize( - "args", [pytest.param(["arg1", "arg2"], id="W/ args"), pytest.param([], id="No args")] - ) - @pytest.mark.parametrize( - "kwargs", - [ - pytest.param({"kwarg1": "value1", "kwarg2": "value2"}, id="W/ kwargs"), - pytest.param({}, id="No kwargs"), - ], - ) - def test_fn_called_w_args(self, mocker, desired_function, args, kwargs): - alarm_time = time.time() + 2 # call fn in 2 seconds - a = Alarm(alarm_time=alarm_time, function=desired_function, args=args, kwargs=kwargs) - a.start() - - assert desired_function.call_count == 0 - time.sleep(1) # hasn't been 2 seconds yet - assert desired_function.call_count == 0 - time.sleep(1.1) # it has now been just over 2 seconds, so the call HAS been made - assert 
desired_function.call_count == 1 - assert desired_function.call_args == mocker.call(*args, **kwargs) - - @pytest.mark.it("Invokes the function with no args or kwargs by default if none are provided") - def test_fn_called_no_args(self, mocker, desired_function): - alarm_time = time.time() + 1 # call fn in 1 seconds - a = Alarm(alarm_time=alarm_time, function=desired_function) - a.start() - - assert desired_function.call_count == 0 - time.sleep(1.1) # it has now been just over 1 second, so the call HAS been made - assert desired_function.call_count == 1 - desired_function.call_args == mocker.call() - - @pytest.mark.it("Invokes the function immediately if the given alarm time is in the past") - def test_alarm_already_expired(self, mocker, desired_function): - alarm_time = time.time() - 1 - a = Alarm(alarm_time=alarm_time, function=desired_function) - a.start() - - assert desired_function.call_count == 1 - - @pytest.mark.it( - "Does not invoke the given function at the given alarm time if the alarm was cancelled before the given alarm time" - ) - def test_cancel_alarm(self, mocker, desired_function): - alarm_time = time.time() + 2 # call fn in 2 seconds - a = Alarm(alarm_time=alarm_time, function=desired_function) - a.start() - - assert desired_function.call_count == 0 - time.sleep(1) # hasn't been 2 seconds yet - assert desired_function.call_count == 0 - a.cancel() # cancel the alarm - time.sleep(1.5) # it has now been more than 2 seconds - assert desired_function.call_count == 0 # still not called diff --git a/tests/unit/common/test_async_adapter.py b/tests/unit/common/test_async_adapter.py deleted file mode 100644 index 800b85d42..000000000 --- a/tests/unit/common/test_async_adapter.py +++ /dev/null @@ -1,145 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import pytest -import inspect -import asyncio -import logging -import azure.iot.device.common.async_adapter as async_adapter - -logging.basicConfig(level=logging.DEBUG) -pytestmark = pytest.mark.asyncio - - -@pytest.fixture -def dummy_value(): - return 123 - - -@pytest.fixture -def mock_function(mocker, dummy_value): - mock_fn = mocker.MagicMock(return_value=dummy_value) - mock_fn.__doc__ = "docstring" - return mock_fn - - -@pytest.mark.describe("emulate_async()") -class TestEmulateAsync(object): - @pytest.mark.it("Returns a coroutine function when given a function") - async def test_returns_coroutine(self, mock_function): - async_fn = async_adapter.emulate_async(mock_function) - assert inspect.iscoroutinefunction(async_fn) - - @pytest.mark.it( - "Returns a coroutine function that returns the result of the input function when called" - ) - async def test_coroutine_returns_input_function_result( - self, mocker, mock_function, dummy_value - ): - async_fn = async_adapter.emulate_async(mock_function) - result = await async_fn(dummy_value) - assert mock_function.call_count == 1 - assert mock_function.call_args == mocker.call(dummy_value) - assert result == mock_function.return_value - - @pytest.mark.it("Copies the input function docstring to resulting coroutine function") - async def test_coroutine_has_input_function_docstring(self, mock_function): - async_fn = async_adapter.emulate_async(mock_function) - assert async_fn.__doc__ == mock_function.__doc__ - - @pytest.mark.it("Can be applied as a decorator") - async def test_applied_as_decorator(self): - - # Define a function with emulate_async applied as a decorator - @async_adapter.emulate_async - def some_function(): - return "foo" - - # Call the function as a coroutine - result = await some_function() - assert result == "foo" - - -@pytest.mark.describe("AwaitableCallback") -class TestAwaitableCallback(object): - @pytest.mark.it("Can be 
instantiated with no args") - async def test_instantiates_without_return_arg_name(self): - callback = async_adapter.AwaitableCallback() - assert isinstance(callback, async_adapter.AwaitableCallback) - - @pytest.mark.it("Can be instantiated with a return_arg_name") - async def test_instantiates_with_return_arg_name(self): - callback = async_adapter.AwaitableCallback(return_arg_name="arg_name") - assert isinstance(callback, async_adapter.AwaitableCallback) - - @pytest.mark.it("Raises a TypeError if return_arg_name is not a string") - async def test_value_error_on_bad_return_arg_name(self): - with pytest.raises(TypeError): - async_adapter.AwaitableCallback(return_arg_name=1) - - @pytest.mark.it( - "Completes the instance Future when a call is invoked on the instance (without return_arg_name)" - ) - async def test_calling_object_completes_future(self): - callback = async_adapter.AwaitableCallback() - assert not callback.future.done() - callback() - await asyncio.sleep(0.1) # wait to give time to complete the callback - assert callback.future.done() - assert not callback.future.exception() - await callback.completion() - - @pytest.mark.it( - "Completes the instance Future when a call is invoked on the instance (with return_arg_name)" - ) - async def test_calling_object_completes_future_with_return_arg_name( - self, fake_return_arg_value - ): - callback = async_adapter.AwaitableCallback(return_arg_name="arg_name") - assert not callback.future.done() - callback(arg_name=fake_return_arg_value) - await asyncio.sleep(0.1) # wait to give time to complete the callback - assert callback.future.done() - assert not callback.future.exception() - assert await callback.completion() == fake_return_arg_value - - @pytest.mark.it( - "Raises a TypeError when a call is invoked on the instance without the correct return argument (with return_arg_name)" - ) - async def test_calling_object_raises_exception_if_return_arg_is_missing( - self, fake_return_arg_value - ): - callback = 
async_adapter.AwaitableCallback(return_arg_name="arg_name") - with pytest.raises(TypeError): - callback() - - @pytest.mark.it( - "Causes an error to be set on the instance Future when an error parameter is passed to the call (without return_arg_name)" - ) - async def test_raises_error_without_return_arg_name(self, arbitrary_exception): - callback = async_adapter.AwaitableCallback() - assert not callback.future.done() - callback(error=arbitrary_exception) - await asyncio.sleep(0.1) # wait to give time to complete the callback - assert callback.future.done() - assert callback.future.exception() == arbitrary_exception - with pytest.raises(arbitrary_exception.__class__) as e_info: - await callback.completion() - assert e_info.value is arbitrary_exception - - @pytest.mark.it( - "Causes an error to be set on the instance Future when an error parameter is passed to the call (with return_arg_name)" - ) - async def test_raises_error_with_return_arg_name(self, arbitrary_exception): - callback = async_adapter.AwaitableCallback(return_arg_name="arg_name") - assert not callback.future.done() - callback(error=arbitrary_exception) - await asyncio.sleep(0.1) # wait to give time to complete the callback - assert callback.future.done() - assert callback.future.exception() == arbitrary_exception - with pytest.raises(arbitrary_exception.__class__) as e_info: - await callback.completion() - assert e_info.value is arbitrary_exception diff --git a/tests/unit/common/test_asyncio_compat.py b/tests/unit/common/test_asyncio_compat.py deleted file mode 100644 index 6b3890e4f..000000000 --- a/tests/unit/common/test_asyncio_compat.py +++ /dev/null @@ -1,146 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import pytest -import asyncio -import sys -import logging -from azure.iot.device.common import asyncio_compat - -logging.basicConfig(level=logging.DEBUG) - -pytestmark = pytest.mark.asyncio - - -@pytest.mark.describe("get_running_loop()") -class TestGetRunningLoop(object): - @pytest.mark.it("Returns the currently running Event Loop in Python 3.7 or higher") - @pytest.mark.skipif(sys.version_info < (3, 7), reason="Requires Python 3.7+") - async def test_returns_currently_running_event_loop_(self, mocker, event_loop): - spy_get_running_loop = mocker.spy(asyncio, "get_running_loop") - result = asyncio_compat.get_running_loop() - assert result == event_loop - assert spy_get_running_loop.call_count == 1 - assert spy_get_running_loop.call_args == mocker.call() - - @pytest.mark.it( - "Raises a RuntimeError if there is no running Event Loop in Python 3.7 or higher" - ) - @pytest.mark.skipif(sys.version_info < (3, 7), reason="Requires Python 3.7+") - async def test_raises_runtime_error_if_no_running_event_loop(self, mocker): - mocker.patch.object(asyncio, "get_running_loop", side_effect=RuntimeError) - with pytest.raises(RuntimeError): - asyncio_compat.get_running_loop() - - @pytest.mark.it("Returns the currently running Event Loop in Python 3.6") - @pytest.mark.skipif(sys.version_info >= (3, 7), reason="Requires Python 3.6") - async def test_returns_currently_running_event_loop_py36orless_compat(self, mocker, event_loop): - spy_get_event_loop = mocker.spy(asyncio, "_get_running_loop") - result = asyncio_compat.get_running_loop() - assert result == event_loop - assert spy_get_event_loop.call_count == 1 - assert spy_get_event_loop.call_args == mocker.call() - - @pytest.mark.it("Raises a RuntimeError if there is no running Event Loop in Python 3.6") - @pytest.mark.skipif(sys.version_info >= (3, 7), reason="Requires Python 3.6") - async def 
test_raises_runtime_error_if_no_running_event_loop_py36orless_compat(self, mocker): - mocker.patch.object(asyncio, "_get_running_loop", return_value=None) - with pytest.raises(RuntimeError): - asyncio_compat.get_running_loop() - - -@pytest.mark.describe("create_task()") -class TestCreateTask(object): - @pytest.fixture - def dummy_coroutine(self): - async def coro(): - return - - return coro - - @pytest.mark.it( - "Returns a Task that wraps a given coroutine, and schedules its execution, in Python 3.7 or higher" - ) - @pytest.mark.skipif(sys.version_info < (3, 7), reason="Requires Python 3.7+") - async def test_returns_task_wrapping_given_coroutine(self, mocker, dummy_coroutine): - spy_create_task = mocker.spy(asyncio, "create_task") - coro_obj = dummy_coroutine() - result = asyncio_compat.create_task(coro_obj) - assert isinstance(result, asyncio.Task) - assert spy_create_task.call_count == 1 - assert spy_create_task.call_args == mocker.call(coro_obj) - - @pytest.mark.it( - "Returns a Task that wraps a given coroutine, and schedules its execution, in Python 3.6" - ) - @pytest.mark.skipif(sys.version_info >= (3, 7), reason="Requires Python 3.6") - async def test_returns_task_wrapping_given_coroutine_py36orless_compat( - self, mocker, dummy_coroutine - ): - spy_ensure_future = mocker.spy(asyncio, "ensure_future") - coro_obj = dummy_coroutine() - result = asyncio_compat.create_task(coro_obj) - assert isinstance(result, asyncio.Task) - assert spy_ensure_future.call_count == 1 - assert spy_ensure_future.call_args == mocker.call(coro_obj) - - -@pytest.mark.describe("run()") -class TestRun(object): - @pytest.mark.it("Runs the given coroutine on a new event loop in Python 3.7 or higher") - @pytest.mark.skipif(sys.version_info < (3, 7), reason="Requires Python 3.7+") - def test_run_37_or_greater(self, mocker): - mock_asyncio_run = mocker.patch.object(asyncio, "run") - mock_coro = mocker.MagicMock() - result = asyncio_compat.run(mock_coro) - assert mock_asyncio_run.call_count 
== 1 - assert mock_asyncio_run.call_args == mocker.call(mock_coro) - assert result == mock_asyncio_run.return_value - - @pytest.mark.it("Runs the given coroutine on a new event loop in Python 3.6") - @pytest.mark.skipif(sys.version_info >= (3, 7), reason="Requires Python 3.6") - def test_run_36orless_compat(self, mocker): - mock_new_event_loop = mocker.patch.object(asyncio, "new_event_loop") - mock_set_event_loop = mocker.patch.object(asyncio, "set_event_loop") - mock_loop = mock_new_event_loop.return_value - - mock_coro = mocker.MagicMock() - result = asyncio_compat.run(mock_coro) - - # New event loop was created and set - assert mock_new_event_loop.call_count == 1 - assert mock_new_event_loop.call_args == mocker.call() - assert mock_set_event_loop.call_count == 2 # (second time at the end) - assert mock_set_event_loop.call_args_list[0] == mocker.call(mock_loop) - # Coroutine was run on the event loop, with the result returned - assert mock_loop.run_until_complete.call_count == 1 - assert mock_loop.run_until_complete.call_args == mocker.call(mock_coro) - assert result == mock_loop.run_until_complete.return_value - # Loop was closed after completion - assert mock_loop.close.call_count == 1 - assert mock_loop.close.call_args == mocker.call() - # Event loop was set back to None - assert mock_set_event_loop.call_args_list[1] == mocker.call(None) - - @pytest.mark.it( - "Closes the event loop and resets to None, even if an error occurs running the coroutine, in Python 3.6" - ) - @pytest.mark.skipif(sys.version_info >= (3, 7), reason="Requires Python 3.6") - def test_error_running_36orless_compat(self, mocker, arbitrary_exception): - # NOTE: This test is not necessary for 3.7 because asyncio.run() does this for us - mock_new_event_loop = mocker.patch.object(asyncio, "new_event_loop") - mock_set_event_loop = mocker.patch.object(asyncio, "set_event_loop") - mock_loop = mock_new_event_loop.return_value - mock_loop.run_until_complete.side_effect = arbitrary_exception - - 
mock_coro = mocker.MagicMock() - with pytest.raises(type(arbitrary_exception)): - asyncio_compat.run(mock_coro) - - assert mock_loop.close.call_count == 1 - assert mock_set_event_loop.call_count == 2 # Once set, once to unset - assert mock_set_event_loop.call_args_list[0] == mocker.call(mock_loop) - assert mock_set_event_loop.call_args_list[1] == mocker.call(None) diff --git a/tests/unit/common/test_evented_callback.py b/tests/unit/common/test_evented_callback.py deleted file mode 100644 index fb4a0761d..000000000 --- a/tests/unit/common/test_evented_callback.py +++ /dev/null @@ -1,90 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest -import logging -from time import sleep -from azure.iot.device.common.evented_callback import EventedCallback - -logging.basicConfig(level=logging.INFO) - - -@pytest.mark.describe("EventedCallback") -class TestEventedCallback(object): - @pytest.mark.it("Can be instantiated with no args") - def test_instantiates_without_return_arg_name(self): - callback = EventedCallback() - assert isinstance(callback, EventedCallback) - - @pytest.mark.it("Can be instantiated with a return_arg_name") - def test_instantiates_with_return_arg_name(self): - callback = EventedCallback(return_arg_name="arg_name") - assert isinstance(callback, EventedCallback) - - @pytest.mark.it("Raises a TypeError if return_arg_name is not a string") - def test_value_error_on_bad_return_arg_name(self): - with pytest.raises(TypeError): - EventedCallback(return_arg_name=1) - - @pytest.mark.it( - "Sets the instance completion Event when a call is invoked on the instance (without return_arg_name)" - ) - def test_calling_object_sets_event(self): - callback = EventedCallback() - assert not 
callback.completion_event.isSet() - callback() - sleep(0.1) # wait to give time to complete the callback - assert callback.completion_event.isSet() - assert not callback.exception - callback.wait_for_completion() - - @pytest.mark.it( - "Sets the instance completion Event when a call is invoked on the instance (with return_arg_name)" - ) - def test_calling_object_sets_event_with_return_arg_name(self, fake_return_arg_value): - callback = EventedCallback(return_arg_name="arg_name") - assert not callback.completion_event.isSet() - callback(arg_name=fake_return_arg_value) - sleep(0.1) # wait to give time to complete the callback - assert callback.completion_event.isSet() - assert not callback.exception - assert callback.wait_for_completion() == fake_return_arg_value - - @pytest.mark.it( - "Raises a TypeError when a call is invoked on the instance without the correct return argument (with return_arg_name)" - ) - def test_calling_object_raises_exception_if_return_arg_is_missing(self, fake_return_arg_value): - callback = EventedCallback(return_arg_name="arg_name") - with pytest.raises(TypeError): - callback() - - @pytest.mark.it( - "Causes an error to be raised from the wait call when an error parameter is passed to the call (without return_arg_name)" - ) - def test_raises_error_without_return_arg_name(self, arbitrary_exception): - callback = EventedCallback() - assert not callback.completion_event.isSet() - callback(error=arbitrary_exception) - sleep(0.1) # wait to give time to complete the callback - assert callback.completion_event.isSet() - assert callback.exception == arbitrary_exception - with pytest.raises(arbitrary_exception.__class__) as e_info: - callback.wait_for_completion() - assert e_info.value is arbitrary_exception - - @pytest.mark.it( - "Causes an error to be raised from the wait call when an error parameter is passed to the call (with return_arg_name)" - ) - def test_raises_error_with_return_arg_name(self, arbitrary_exception): - callback = 
EventedCallback(return_arg_name="arg_name") - assert not callback.completion_event.isSet() - callback(error=arbitrary_exception) - sleep(0.1) # wait to give time to complete the callback - assert callback.completion_event.isSet() - assert callback.exception == arbitrary_exception - with pytest.raises(arbitrary_exception.__class__) as e_info: - callback.wait_for_completion() - assert e_info.value is arbitrary_exception diff --git a/tests/unit/common/test_http_transport.py b/tests/unit/common/test_http_transport.py deleted file mode 100644 index 9f89d14b9..000000000 --- a/tests/unit/common/test_http_transport.py +++ /dev/null @@ -1,385 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from azure.iot.device.common.http_transport import HTTPTransport, HTTP_TIMEOUT -from azure.iot.device.common.models import X509, ProxyOptions -from azure.iot.device.common import transport_exceptions as errors -import pytest -import logging -import ssl -import urllib3 -import requests - -logging.basicConfig(level=logging.DEBUG) - -# Monkeypatch to bypass the decorator that runs on a separate thread -HTTPTransport.request = HTTPTransport.request.__wrapped__ - -fake_hostname = "fake.hostname" -fake_path = "path/to/resource" - - -fake_server_verification_cert = "__fake_server_verification_cert__" -fake_x509_cert = "__fake_x509_certificate__" -fake_cipher = "DHE-RSA-AES128-SHA" - - -@pytest.mark.describe("HTTPTransport - Instantiation") -class TestInstantiation(object): - @pytest.fixture( - params=["HTTP - No Auth", "HTTP - Auth", "SOCKS4", "SOCKS5 - No Auth", "SOCKS5 - Auth"] - ) - def proxy_options(self, request): - if "HTTP" in request.param: - proxy_type = "HTTP" - elif "SOCKS4" in request.param: - proxy_type = 
"SOCKS4" - else: - proxy_type = "SOCKS5" - - if "No Auth" in request.param: - proxy = ProxyOptions(proxy_type=proxy_type, proxy_addr="127.0.0.1", proxy_port=1080) - else: - proxy = ProxyOptions( - proxy_type=proxy_type, - proxy_addr="127.0.0.1", - proxy_port=1080, - proxy_username="fake_username", - proxy_password="fake_password", - ) - return proxy - - @pytest.mark.it("Stores the hostname for later use") - def test_sets_required_parameters(self, mocker): - - mocker.patch.object(ssl, "SSLContext").return_value - mocker.patch.object(HTTPTransport, "_create_ssl_context").return_value - - http_transport_object = HTTPTransport( - hostname=fake_hostname, - server_verification_cert=fake_server_verification_cert, - x509_cert=fake_x509_cert, - cipher=fake_cipher, - ) - - assert http_transport_object._hostname == fake_hostname - - @pytest.mark.it( - "Creates a dictionary of proxies from the 'proxy_options' parameter, if the parameter is provided" - ) - def test_proxy_format(self, proxy_options): - http_transport_object = HTTPTransport(hostname=fake_hostname, proxy_options=proxy_options) - - if proxy_options.proxy_username and proxy_options.proxy_password: - expected_proxy_string = "{username}:{password}@{address}:{port}".format( - username=proxy_options.proxy_username, - password=proxy_options.proxy_password, - address=proxy_options.proxy_address, - port=proxy_options.proxy_port, - ) - else: - expected_proxy_string = "{address}:{port}".format( - address=proxy_options.proxy_address, port=proxy_options.proxy_port - ) - - if proxy_options.proxy_type == "HTTP": - expected_proxy_string = "http://" + expected_proxy_string - elif proxy_options.proxy_type == "SOCKS4": - expected_proxy_string = "socks4://" + expected_proxy_string - else: - expected_proxy_string = "socks5://" + expected_proxy_string - - assert isinstance(http_transport_object._proxies, dict) - assert http_transport_object._proxies["http"] == expected_proxy_string - assert http_transport_object._proxies["https"] == 
expected_proxy_string - - @pytest.mark.it( - "Configures TLS/SSL context to use TLS 1.2, require certificates and check hostname" - ) - def test_configures_tls_context(self, mocker): - mock_ssl_context_constructor = mocker.patch.object(ssl, "SSLContext") - mock_ssl_context = mock_ssl_context_constructor.return_value - - HTTPTransport(hostname=fake_hostname) - # Verify correctness of TLS/SSL Context - assert mock_ssl_context_constructor.call_count == 1 - assert mock_ssl_context_constructor.call_args == mocker.call(protocol=ssl.PROTOCOL_TLSv1_2) - assert mock_ssl_context.check_hostname is True - assert mock_ssl_context.verify_mode == ssl.CERT_REQUIRED - - @pytest.mark.it( - "Configures TLS/SSL context using default certificates if protocol wrapper not instantiated with a server verification certificate" - ) - def test_configures_tls_context_with_default_certs(self, mocker): - mock_ssl_context = mocker.patch.object(ssl, "SSLContext").return_value - - HTTPTransport(hostname=fake_hostname) - - assert mock_ssl_context.load_default_certs.call_count == 1 - assert mock_ssl_context.load_default_certs.call_args == mocker.call() - - @pytest.mark.it( - "Configures TLS/SSL context with provided server verification certificate if protocol wrapper instantiated with a server verification certificate" - ) - def test_configures_tls_context_with_server_verification_certs(self, mocker): - mock_ssl_context = mocker.patch.object(ssl, "SSLContext").return_value - - HTTPTransport( - hostname=fake_hostname, server_verification_cert=fake_server_verification_cert - ) - - assert mock_ssl_context.load_verify_locations.call_count == 1 - assert mock_ssl_context.load_verify_locations.call_args == mocker.call( - cadata=fake_server_verification_cert - ) - - @pytest.mark.it( - "Configures TLS/SSL context with provided cipher if present during instantiation" - ) - def test_configures_tls_context_with_cipher(self, mocker): - mock_ssl_context = mocker.patch.object(ssl, "SSLContext").return_value - - 
HTTPTransport(hostname=fake_hostname, cipher=fake_cipher) - - assert mock_ssl_context.set_ciphers.call_count == 1 - assert mock_ssl_context.set_ciphers.call_args == mocker.call(fake_cipher) - - @pytest.mark.it("Configures TLS/SSL context with client-provided-certificate-chain like x509") - def test_configures_tls_context_with_client_provided_certificate_chain(self, mocker): - fake_client_cert = X509("fake_cert_file", "fake_key_file", "fake pass phrase") - mock_ssl_context_constructor = mocker.patch.object(ssl, "SSLContext") - mock_ssl_context = mock_ssl_context_constructor.return_value - - HTTPTransport(hostname=fake_hostname, x509_cert=fake_client_cert) - - assert mock_ssl_context.load_default_certs.call_count == 1 - assert mock_ssl_context.load_cert_chain.call_count == 1 - assert mock_ssl_context.load_cert_chain.call_args == mocker.call( - fake_client_cert.certificate_file, - fake_client_cert.key_file, - fake_client_cert.pass_phrase, - ) - - @pytest.mark.it( - "Creates a custom requests HTTP Adapter that uses the configured SSL context when creating PoolManagers" - ) - def test_http_adapter_pool_manager(self, mocker): - # NOTE: This test involves mocking and testing deeper parts of the requests library stack - # in order to show that the HTTPAdapter is functioning as intended. 
This naturally gets a - # little messy from a unit testing perspective - poolmanager_init_mock = mocker.patch.object(requests.adapters, "PoolManager") - proxymanager_init_mock = mocker.patch.object(urllib3.poolmanager, "ProxyManager") - socksproxymanager_init_mock = mocker.patch.object(requests.adapters, "SOCKSProxyManager") - ssl_context_init_mock = mocker.patch.object(ssl, "SSLContext") - mock_ssl_context = ssl_context_init_mock.return_value - - http_transport_object = HTTPTransport(hostname=fake_hostname) - # SSL Context was only created once - assert ssl_context_init_mock.call_count == 1 - # HTTP Adapter was set on the transport - assert isinstance(http_transport_object._http_adapter, requests.adapters.HTTPAdapter) - - # Reset the poolmanager mock because it's already been called upon instantiation of the adapter - # (We will manually test scenarios in which a PoolManager is instantiated) - poolmanager_init_mock.reset_mock() - - # Basic PoolManager init scenario - http_transport_object._http_adapter.init_poolmanager( - connections=requests.adapters.DEFAULT_POOLSIZE, - maxsize=requests.adapters.DEFAULT_POOLSIZE, - ) - assert poolmanager_init_mock.call_count == 1 - assert poolmanager_init_mock.call_args[1]["ssl_context"] == mock_ssl_context - - # ProxyManager init scenario - http_transport_object._http_adapter.proxy_manager_for(proxy="http://127.0.0.1") - assert proxymanager_init_mock.call_count == 1 - assert proxymanager_init_mock.call_args[1]["ssl_context"] == mock_ssl_context - - # SOCKSProxyManager init scenario - http_transport_object._http_adapter.proxy_manager_for(proxy="socks5://127.0.0.1") - assert socksproxymanager_init_mock.call_count == 1 - assert socksproxymanager_init_mock.call_args[1]["ssl_context"] == mock_ssl_context - - # SSL Context was still only ever created once. 
This proves that the SSL context being - # used above is the same one that was configured in a custom way - assert ssl_context_init_mock.call_count == 1 - - -@pytest.mark.describe("HTTPTransport - .request()") -class TestRequest(object): - @pytest.fixture(autouse=True) - def mock_requests_session(self, mocker): - return mocker.patch.object(requests, "Session") - - @pytest.fixture - def session(self, mock_requests_session): - return mock_requests_session.return_value - - @pytest.fixture - def transport(self): - return HTTPTransport(hostname=fake_hostname) - - @pytest.fixture(params=["GET", "POST", "PUT", "PATCH", "DELETE"]) - def request_method(self, request): - return request.param - - @pytest.mark.it( - "Mounts the custom HTTP Adapter on a new requests Session before making a request" - ) - def test_mount_adapter(self, mocker, transport, mock_requests_session, request_method): - session = mock_requests_session.return_value - session_method = getattr(session, request_method.lower()) - - # Check that the request has not yet been made when mounted - def check_request_not_made(*args): - assert session_method.call_count == 0 - - session.mount.side_effect = check_request_not_made - - # Session has not yet been created - assert mock_requests_session.call_count == 0 - - # Request - transport.request(request_method, fake_path, mocker.MagicMock()) - - # Session has been created - assert mock_requests_session.call_count == 1 - assert mock_requests_session.call_args == mocker.call() - assert session is mock_requests_session.return_value - # Adapter has been mounted - assert session.mount.call_count == 1 - assert session.mount.call_args == mocker.call("https://", transport._http_adapter) - # Request was made after (see above side effect for proof that this happens after mount) - assert session_method.call_count == 1 - - @pytest.mark.it( - "Makes a HTTP request with the new Session using the given parameters, stored hostname and stored proxy" - ) - @pytest.mark.parametrize( - 
"hostname, path, query_params, expected_url", - [ - pytest.param( - "fake.hostname", - "path/to/resource", - "", - "https://fake.hostname/path/to/resource", - id="No query parameters", - ), - pytest.param( - "fake.hostname", - "path/to/resource", - "arg1=val1;arg2=val2", - "https://fake.hostname/path/to/resource?arg1=val1;arg2=val2", - id="With query parameters", - ), - ], - ) - @pytest.mark.parametrize( - "body", [pytest.param("", id="No body"), pytest.param("fake body", id="With body")] - ) - @pytest.mark.parametrize( - "headers", - [pytest.param({}, id="No headers"), pytest.param({"Key": "Value"}, id="With headers")], - ) - def test_request( - self, - mocker, - transport, - mock_requests_session, - request_method, - hostname, - path, - query_params, - expected_url, - body, - headers, - ): - transport._hostname = hostname - transport.request( - method=request_method, - path=path, - callback=mocker.MagicMock(), - body=body, - headers=headers, - query_params=query_params, - ) - - # New session was created - assert mock_requests_session.call_count == 1 - assert mock_requests_session.call_args == mocker.call() - session = mock_requests_session.return_value - assert session.mount.call_count == 1 - assert session.mount.call_args == mocker.call("https://", transport._http_adapter) - - # The relevant method was called on the session - session_method = getattr(session, request_method.lower()) - assert session_method.call_count == 1 - assert session_method.call_args == mocker.call( - expected_url, - data=body, - headers=headers, - proxies=transport._proxies, - timeout=HTTP_TIMEOUT, - ) - - @pytest.mark.it( - "Creates a response object containing the status code, reason and text from the HTTP response and returns it via the callback" - ) - def test_returns_response(self, mocker, transport, session, request_method): - session_method = getattr(session, request_method.lower()) - response = session_method.return_value - cb_mock = mocker.MagicMock() - - 
transport.request(method=request_method, path=fake_path, callback=cb_mock) - - assert cb_mock.call_count == 1 - assert cb_mock.call_args == mocker.call(response=mocker.ANY) - response_obj = cb_mock.call_args[1]["response"] - assert response_obj["status_code"] == response.status_code - assert response_obj["reason"] == response.reason - assert response_obj["resp"] == response.text - - @pytest.mark.it( - "Returns a ValueError via the callback if the request method provided is not valid" - ) - def test_invalid_method(self, mocker, transport): - cb_mock = mocker.MagicMock() - transport.request(method="NOT A REAL METHOD", path=fake_path, callback=cb_mock) - - assert cb_mock.call_count == 1 - error = cb_mock.call_args[1]["error"] - assert isinstance(error, ValueError) - - @pytest.mark.it( - "Returns a requests.exceptions.Timeout via the callback if the HTTP request times out" - ) - def test_request_timeout(self, mocker, transport, session, request_method): - session_method = getattr(session, request_method.lower()) - session_method.side_effect = requests.exceptions.Timeout - cb_mock = mocker.MagicMock() - - transport.request(method=request_method, path=fake_path, callback=cb_mock) - - assert cb_mock.call_count == 1 - error = cb_mock.call_args[1]["error"] - assert isinstance(error, requests.exceptions.Timeout) - - @pytest.mark.it( - "Returns a ProtocolClientError via the callback if making the HTTP request raises an unexpected Exception" - ) - def test_client_raises_unexpected_error( - self, mocker, transport, session, request_method, arbitrary_exception - ): - session_method = getattr(session, request_method.lower()) - session_method.side_effect = arbitrary_exception - cb_mock = mocker.MagicMock() - - transport.request(method=request_method, path=fake_path, callback=cb_mock) - - assert cb_mock.call_count == 1 - error = cb_mock.call_args[1]["error"] - assert isinstance(error, errors.ProtocolClientError) - assert error.__cause__ is arbitrary_exception diff --git 
a/tests/unit/common/test_mqtt_transport.py b/tests/unit/common/test_mqtt_transport.py deleted file mode 100644 index ee547d03a..000000000 --- a/tests/unit/common/test_mqtt_transport.py +++ /dev/null @@ -1,2532 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -from azure.iot.device.common.mqtt_transport import MQTTTransport, OperationManager -from azure.iot.device.common.models.x509 import X509 -from azure.iot.device.common import transport_exceptions as errors -from azure.iot.device.common import ProxyOptions -import paho.mqtt.client as mqtt -import ssl -import copy -import pytest -import logging -import socket -import socks -import threading -import gc -import weakref - -logging.basicConfig(level=logging.DEBUG) - -fake_hostname = "fake.hostname" -fake_device_id = "MyDevice" -fake_password = "fake_password" -fake_username = fake_hostname + "/" + fake_device_id -new_fake_password = "new fake password" -fake_topic = "fake_topic" -fake_payload = "some payload" -fake_cipher = "DHE-RSA-AES128-SHA" -fake_qos = 1 -fake_mid = 52 -fake_rc = 0 -fake_success_rc = 0 -fake_failed_rc = mqtt.MQTT_ERR_PROTOCOL -failed_connack_rc = mqtt.CONNACK_REFUSED_IDENTIFIER_REJECTED -fake_keepalive = 1234 - - -# mapping of Paho connack rc codes to Error object classes -connack_return_codes = [ - { - "name": "CONNACK_REFUSED_PROTOCOL_VERSION", - "rc": mqtt.CONNACK_REFUSED_PROTOCOL_VERSION, - "error": errors.ProtocolClientError, - }, - { - "name": "CONNACK_REFUSED_IDENTIFIER_REJECTED", - "rc": mqtt.CONNACK_REFUSED_IDENTIFIER_REJECTED, - "error": errors.ProtocolClientError, - }, - { - "name": "CONNACK_REFUSED_SERVER_UNAVAILABLE", - "rc": mqtt.CONNACK_REFUSED_SERVER_UNAVAILABLE, - "error": errors.ConnectionFailedError, - }, - 
{ - "name": "CONNACK_REFUSED_BAD_USERNAME_PASSWORD", - "rc": mqtt.CONNACK_REFUSED_BAD_USERNAME_PASSWORD, - "error": errors.UnauthorizedError, - }, - { - "name": "CONNACK_REFUSED_NOT_AUTHORIZED", - "rc": mqtt.CONNACK_REFUSED_NOT_AUTHORIZED, - "error": errors.UnauthorizedError, - }, -] - - -# mapping of Paho rc codes to Error object classes -operation_return_codes = [ - {"name": "MQTT_ERR_NOMEM", "rc": mqtt.MQTT_ERR_NOMEM, "error": errors.ConnectionDroppedError}, - { - "name": "MQTT_ERR_PROTOCOL", - "rc": mqtt.MQTT_ERR_PROTOCOL, - "error": errors.ProtocolClientError, - }, - {"name": "MQTT_ERR_INVAL", "rc": mqtt.MQTT_ERR_INVAL, "error": errors.ProtocolClientError}, - {"name": "MQTT_ERR_NO_CONN", "rc": mqtt.MQTT_ERR_NO_CONN, "error": errors.NoConnectionError}, - { - "name": "MQTT_ERR_CONN_REFUSED", - "rc": mqtt.MQTT_ERR_CONN_REFUSED, - "error": errors.ConnectionFailedError, - }, - { - "name": "MQTT_ERR_NOT_FOUND", - "rc": mqtt.MQTT_ERR_NOT_FOUND, - "error": errors.ConnectionFailedError, - }, - { - "name": "MQTT_ERR_CONN_LOST", - "rc": mqtt.MQTT_ERR_CONN_LOST, - "error": errors.ConnectionDroppedError, - }, - {"name": "MQTT_ERR_TLS", "rc": mqtt.MQTT_ERR_TLS, "error": errors.UnauthorizedError}, - { - "name": "MQTT_ERR_PAYLOAD_SIZE", - "rc": mqtt.MQTT_ERR_PAYLOAD_SIZE, - "error": errors.ProtocolClientError, - }, - { - "name": "MQTT_ERR_NOT_SUPPORTED", - "rc": mqtt.MQTT_ERR_NOT_SUPPORTED, - "error": errors.ProtocolClientError, - }, - {"name": "MQTT_ERR_AUTH", "rc": mqtt.MQTT_ERR_AUTH, "error": errors.UnauthorizedError}, - { - "name": "MQTT_ERR_ACL_DENIED", - "rc": mqtt.MQTT_ERR_ACL_DENIED, - "error": errors.UnauthorizedError, - }, - {"name": "MQTT_ERR_UNKNOWN", "rc": mqtt.MQTT_ERR_UNKNOWN, "error": errors.ProtocolClientError}, - {"name": "MQTT_ERR_ERRNO", "rc": mqtt.MQTT_ERR_ERRNO, "error": errors.ProtocolClientError}, - { - "name": "MQTT_ERR_QUEUE_SIZE", - "rc": mqtt.MQTT_ERR_QUEUE_SIZE, - "error": errors.ProtocolClientError, - }, - { - "name": "MQTT_ERR_KEEPALIVE", - 
"rc": mqtt.MQTT_ERR_KEEPALIVE, - "error": errors.ConnectionDroppedError, - }, -] - - -@pytest.fixture -def mock_mqtt_client(mocker, fake_paho_thread): - mock = mocker.patch.object(mqtt, "Client") - mock_mqtt_client = mock.return_value - mock_mqtt_client.subscribe = mocker.MagicMock(return_value=(fake_rc, fake_mid)) - mock_mqtt_client.unsubscribe = mocker.MagicMock(return_value=(fake_rc, fake_mid)) - mock_mqtt_client.publish = mocker.MagicMock(return_value=(fake_rc, fake_mid)) - mock_mqtt_client.connect.return_value = 0 - mock_mqtt_client.reconnect.return_value = 0 - mock_mqtt_client.disconnect.return_value = 0 - mock_mqtt_client._thread = fake_paho_thread - return mock_mqtt_client - - -@pytest.fixture -def transport(mock_mqtt_client): - # Implicitly imports the mocked Paho MQTT Client from mock_mqtt_client - return MQTTTransport(client_id=fake_device_id, hostname=fake_hostname, username=fake_username) - - -@pytest.fixture -def fake_paho_thread(mocker): - thread = mocker.MagicMock(spec=threading.Thread) - thread.name = "_fake_paho_thread_" - return thread - - -@pytest.fixture -def mock_paho_thread_current(mocker, fake_paho_thread): - return mocker.patch.object(threading, "current_thread", return_value=fake_paho_thread) - - -@pytest.fixture -def fake_non_paho_thread(mocker): - thread = mocker.MagicMock(spec=threading.Thread) - thread.name = "_fake_non_paho_thread_" - return thread - - -@pytest.fixture -def mock_non_paho_thread_current(mocker, fake_non_paho_thread): - return mocker.patch.object(threading, "current_thread", return_value=fake_non_paho_thread) - - -@pytest.mark.describe("MQTTTransport - Instantiation") -class TestInstantiation(object): - @pytest.fixture( - params=["HTTP - No Auth", "HTTP - Auth", "SOCKS4", "SOCKS5 - No Auth", "SOCKS5 - Auth"] - ) - def proxy_options(self, request): - if "HTTP" in request.param: - proxy_type = "HTTP" - elif "SOCKS4" in request.param: - proxy_type = "SOCKS4" - else: - proxy_type = "SOCKS5" - - if "No Auth" in 
request.param: - proxy = ProxyOptions(proxy_type=proxy_type, proxy_addr="fake.address", proxy_port=1080) - else: - proxy = ProxyOptions( - proxy_type=proxy_type, - proxy_addr="fake.address", - proxy_port=1080, - proxy_username="fake_username", - proxy_password="fake_password", - ) - return proxy - - @pytest.mark.it("Creates an instance of the Paho MQTT Client") - def test_instantiates_mqtt_client(self, mocker): - mock_mqtt_client_constructor = mocker.patch.object(mqtt, "Client") - - MQTTTransport(client_id=fake_device_id, hostname=fake_hostname, username=fake_username) - - assert mock_mqtt_client_constructor.call_count == 1 - assert mock_mqtt_client_constructor.call_args == mocker.call( - client_id=fake_device_id, clean_session=False, protocol=mqtt.MQTTv311 - ) - - @pytest.mark.it( - "Creates an instance of the Paho MQTT Client using Websockets when websockets parameter is True" - ) - def test_configures_mqtt_websockets(self, mocker): - mock_mqtt_client_constructor = mocker.patch.object(mqtt, "Client") - mock_mqtt_client = mock_mqtt_client_constructor.return_value - - MQTTTransport( - client_id=fake_device_id, - hostname=fake_hostname, - username=fake_username, - websockets=True, - ) - - assert mock_mqtt_client_constructor.call_count == 1 - assert mock_mqtt_client_constructor.call_args == mocker.call( - client_id=fake_device_id, - clean_session=False, - protocol=mqtt.MQTTv311, - transport="websockets", - ) - - # Verify websockets options have been set - assert mock_mqtt_client.ws_set_options.call_count == 1 - assert mock_mqtt_client.ws_set_options.call_args == mocker.call(path="/$iothub/websocket") - - @pytest.mark.it( - "Sets the proxy information on the client when the `proxy_options` parameter is provided" - ) - def test_proxy_config(self, mocker, proxy_options): - mock_mqtt_client_constructor = mocker.patch.object(mqtt, "Client") - mock_mqtt_client = mock_mqtt_client_constructor.return_value - - MQTTTransport( - client_id=fake_device_id, - 
hostname=fake_hostname, - username=fake_username, - proxy_options=proxy_options, - ) - - # Verify proxy has been set - assert mock_mqtt_client.proxy_set.call_count == 1 - assert mock_mqtt_client.proxy_set.call_args == mocker.call( - proxy_type=proxy_options.proxy_type_socks, - proxy_addr=proxy_options.proxy_address, - proxy_port=proxy_options.proxy_port, - proxy_username=proxy_options.proxy_username, - proxy_password=proxy_options.proxy_password, - ) - - @pytest.mark.it( - "Configures TLS/SSL context to use TLS 1.2, require certificates and check hostname" - ) - def test_configures_tls_context(self, mocker): - mock_mqtt_client = mocker.patch.object(mqtt, "Client").return_value - mock_ssl_context_constructor = mocker.patch.object(ssl, "SSLContext") - mock_ssl_context = mock_ssl_context_constructor.return_value - - MQTTTransport(client_id=fake_device_id, hostname=fake_hostname, username=fake_username) - - # Verify correctness of TLS/SSL Context - assert mock_ssl_context_constructor.call_count == 1 - assert mock_ssl_context_constructor.call_args == mocker.call(protocol=ssl.PROTOCOL_TLSv1_2) - assert mock_ssl_context.check_hostname is True - assert mock_ssl_context.verify_mode == ssl.CERT_REQUIRED - - # Verify context has been set - assert mock_mqtt_client.tls_set_context.call_count == 1 - assert mock_mqtt_client.tls_set_context.call_args == mocker.call(context=mock_ssl_context) - - @pytest.mark.it( - "Configures TLS/SSL context using default certificates if protocol wrapper not instantiated with a server verification certificate" - ) - def test_configures_tls_context_with_default_certs(self, mocker, mock_mqtt_client): - mock_ssl_context_constructor = mocker.patch.object(ssl, "SSLContext") - mock_ssl_context = mock_ssl_context_constructor.return_value - - MQTTTransport(client_id=fake_device_id, hostname=fake_hostname, username=fake_username) - - assert mock_ssl_context.load_default_certs.call_count == 1 - assert mock_ssl_context.load_default_certs.call_args == 
mocker.call() - - @pytest.mark.it( - "Configures TLS/SSL context with provided server verification certificate if protocol wrapper instantiated with a server verification certificate" - ) - def test_configures_tls_context_with_server_verification_certs(self, mocker, mock_mqtt_client): - mock_ssl_context_constructor = mocker.patch.object(ssl, "SSLContext") - mock_ssl_context = mock_ssl_context_constructor.return_value - server_verification_cert = "dummy_certificate" - - MQTTTransport( - client_id=fake_device_id, - hostname=fake_hostname, - username=fake_username, - server_verification_cert=server_verification_cert, - ) - - assert mock_ssl_context.load_verify_locations.call_count == 1 - assert mock_ssl_context.load_verify_locations.call_args == mocker.call( - cadata=server_verification_cert - ) - - @pytest.mark.it( - "Configures TLS/SSL context with provided cipher if present during instantiation" - ) - def test_configures_tls_context_with_cipher(self, mocker, mock_mqtt_client): - mock_ssl_context_constructor = mocker.patch.object(ssl, "SSLContext") - mock_ssl_context = mock_ssl_context_constructor.return_value - - MQTTTransport( - client_id=fake_device_id, - hostname=fake_hostname, - username=fake_username, - cipher=fake_cipher, - ) - - assert mock_ssl_context.set_ciphers.call_count == 1 - assert mock_ssl_context.set_ciphers.call_args == mocker.call(fake_cipher) - - @pytest.mark.it("Configures TLS/SSL context with client-provided-certificate-chain like x509") - def test_configures_tls_context_with_client_provided_certificate_chain( - self, mocker, mock_mqtt_client - ): - mock_ssl_context_constructor = mocker.patch.object(ssl, "SSLContext") - mock_ssl_context = mock_ssl_context_constructor.return_value - fake_client_cert = X509("fake_cert_file", "fake_key_file", "fake pass phrase") - - MQTTTransport( - client_id=fake_device_id, - hostname=fake_hostname, - username=fake_username, - x509_cert=fake_client_cert, - ) - - assert 
mock_ssl_context.load_default_certs.call_count == 1 - assert mock_ssl_context.load_cert_chain.call_count == 1 - assert mock_ssl_context.load_cert_chain.call_args == mocker.call( - fake_client_cert.certificate_file, - fake_client_cert.key_file, - fake_client_cert.pass_phrase, - ) - - @pytest.mark.it("Sets Paho MQTT Client callbacks") - def test_sets_paho_callbacks(self, mocker): - mock_mqtt_client = mocker.patch.object(mqtt, "Client").return_value - - MQTTTransport(client_id=fake_device_id, hostname=fake_hostname, username=fake_username) - - assert callable(mock_mqtt_client.on_connect) - assert callable(mock_mqtt_client.on_disconnect) - assert callable(mock_mqtt_client.on_subscribe) - assert callable(mock_mqtt_client.on_unsubscribe) - assert callable(mock_mqtt_client.on_publish) - assert callable(mock_mqtt_client.on_message) - - @pytest.mark.it("Initializes event handlers to 'None'") - def test_handler_callbacks_set_to_none(self, mocker): - mocker.patch.object(mqtt, "Client") - - transport = MQTTTransport( - client_id=fake_device_id, hostname=fake_hostname, username=fake_username - ) - - assert transport.on_mqtt_connected_handler is None - assert transport.on_mqtt_disconnected_handler is None - assert transport.on_mqtt_message_received_handler is None - - @pytest.mark.it("Initializes internal operation tracking structures") - def test_operation_infrastructure_set_up(self, mocker): - transport = MQTTTransport( - client_id=fake_device_id, hostname=fake_hostname, username=fake_username - ) - assert transport._op_manager._pending_operation_callbacks == {} - assert transport._op_manager._unknown_operation_completions == {} - - @pytest.mark.it("Sets paho auto-reconnect interval to 2 hours") - def test_sets_reconnect_interval(self, mocker, transport, mock_mqtt_client): - MQTTTransport(client_id=fake_device_id, hostname=fake_hostname, username=fake_username) - - # called once by the mqtt_client constructor and once by mqtt_transport.py - assert 
mock_mqtt_client.reconnect_delay_set.call_count == 2 - assert mock_mqtt_client.reconnect_delay_set.call_args == mocker.call(120 * 60) - - -@pytest.mark.describe("MQTTTransport - .shutdown()") -class TestShutdown(object): - @pytest.mark.it("Force Disconnects Paho") - def test_disconnects(self, mocker, mock_mqtt_client, transport): - transport.shutdown() - - assert mock_mqtt_client.disconnect.call_count == 1 - assert mock_mqtt_client.disconnect.call_args == mocker.call() - assert mock_mqtt_client.loop_stop.call_count == 1 - assert mock_mqtt_client.loop_stop.call_args == mocker.call() - - @pytest.mark.it("Does NOT trigger the on_disconnect handler upon disconnect") - def test_does_not_trigger_handler(self, mocker, mock_mqtt_client, transport): - mock_disconnect_handler = mocker.MagicMock() - mock_mqtt_client.on_disconnect = mock_disconnect_handler - transport.shutdown() - assert mock_mqtt_client.on_disconnect is None - assert mock_disconnect_handler.call_count == 0 - - -class ArbitraryConnectException(Exception): - pass - - -@pytest.mark.describe("MQTTTransport - .connect()") -class TestConnect(object): - @pytest.mark.it("Uses the stored username and provided password for Paho credentials") - def test_use_provided_password(self, mocker, mock_mqtt_client, transport): - transport.connect(fake_password) - - assert mock_mqtt_client.username_pw_set.call_count == 1 - assert mock_mqtt_client.username_pw_set.call_args == mocker.call( - username=transport._username, password=fake_password - ) - - @pytest.mark.it( - "Uses the stored username without a password for Paho credentials, if password is not provided" - ) - def test_use_no_password(self, mocker, mock_mqtt_client, transport): - transport.connect() - - assert mock_mqtt_client.username_pw_set.call_count == 1 - assert mock_mqtt_client.username_pw_set.call_args == mocker.call( - username=transport._username, password=None - ) - - @pytest.mark.it("Initiates MQTT connect via Paho") - @pytest.mark.parametrize( - "password", - 
[ - pytest.param(fake_password, id="Password provided"), - pytest.param(None, id="No password provided"), - ], - ) - @pytest.mark.parametrize( - "websockets,port", - [ - pytest.param(False, 8883, id="Not using websockets"), - pytest.param(True, 443, id="Using websockets"), - ], - ) - def test_calls_paho_connect( - self, mocker, mock_mqtt_client, transport, password, websockets, port - ): - - # We don't want to use a special fixture for websockets, so instead we are overriding the attribute below. - # However, we want to assert that this value is not undefined. For instance, the self._websockets convention private attribute - # could be changed to self._websockets1, and all our tests would still pass without the below assert statement. - assert transport._websockets is False - - transport._websockets = websockets - fake_keepalive = 900 - transport._keep_alive = fake_keepalive - - transport.connect(password) - - assert mock_mqtt_client.connect.call_count == 1 - assert mock_mqtt_client.connect.call_args == mocker.call( - host=fake_hostname, port=port, keepalive=fake_keepalive - ) - - @pytest.mark.it("Starts MQTT Network Loop") - @pytest.mark.parametrize( - "password", - [ - pytest.param(fake_password, id="Password provided"), - pytest.param(None, id="No password provided"), - ], - ) - def test_calls_loop_start(self, mocker, mock_mqtt_client, transport, password): - transport.connect(password) - - assert mock_mqtt_client.loop_start.call_count == 1 - assert mock_mqtt_client.loop_start.call_args == mocker.call() - - @pytest.mark.it("Raises a ProtocolClientError if Paho connect raises an unexpected Exception") - def test_client_raises_unexpected_error( - self, mocker, mock_mqtt_client, transport, arbitrary_exception - ): - mock_mqtt_client.connect.side_effect = arbitrary_exception - with pytest.raises(errors.ProtocolClientError) as e_info: - transport.connect(fake_password) - assert e_info.value.__cause__ is arbitrary_exception - - @pytest.mark.it( - "Raises a 
ConnectionFailedError if Paho connect raises a socket.error Exception" - ) - def test_client_raises_socket_error( - self, mocker, mock_mqtt_client, transport, arbitrary_exception - ): - socket_error = socket.error() - mock_mqtt_client.connect.side_effect = socket_error - with pytest.raises(errors.ConnectionFailedError) as e_info: - transport.connect(fake_password) - assert e_info.value.__cause__ is socket_error - - @pytest.mark.it( - "Raises a TlsExchangeAuthError if Paho connect raises a socket.error of type SSLCertVerificationError Exception" - ) - def test_client_raises_socket_tls_auth_error( - self, mocker, mock_mqtt_client, transport, arbitrary_exception - ): - socket_error = ssl.SSLError("socket error", "CERTIFICATE_VERIFY_FAILED") - mock_mqtt_client.connect.side_effect = socket_error - with pytest.raises(errors.TlsExchangeAuthError) as e_info: - transport.connect(fake_password) - assert e_info.value.__cause__ is socket_error - print(e_info.value.__cause__.strerror) - - @pytest.mark.it( - "Raises a ProtocolProxyError if Paho connect raises a socket error or a ProxyError exception" - ) - def test_client_raises_socket_error_or_proxy_error_as_proxy_error( - self, mocker, mock_mqtt_client, transport, arbitrary_exception - ): - socks_error = socks.SOCKS5Error( - "it is a sock 5 error", socket_err="a general SOCKS5Error error" - ) - mock_mqtt_client.connect.side_effect = socks_error - with pytest.raises(errors.ProtocolProxyError) as e_info: - transport.connect(fake_password) - assert e_info.value.__cause__ is socks_error - print(e_info.value.__cause__.strerror) - - @pytest.mark.it( - "Raises a UnauthorizedError if Paho connect raises a socket error or a ProxyError exception" - ) - def test_client_raises_socket_error_or_proxy_error_as_unauthorized_error( - self, mocker, mock_mqtt_client, transport, arbitrary_exception - ): - socks_error = socks.SOCKS5AuthError( - "it is a sock 5 auth error", socket_err="an auth SOCKS5Error error" - ) - 
mock_mqtt_client.connect.side_effect = socks_error - with pytest.raises(errors.UnauthorizedError) as e_info: - transport.connect(fake_password) - assert e_info.value.__cause__ is socks_error - print(e_info.value.__cause__.strerror) - - @pytest.mark.it("Allows any BaseExceptions raised in Paho connect to propagate") - def test_client_raises_base_exception( - self, mock_mqtt_client, transport, arbitrary_base_exception - ): - mock_mqtt_client.connect.side_effect = arbitrary_base_exception - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - transport.connect(fake_password) - assert e_info.value is arbitrary_base_exception - - # NOTE: this test tests for all possible return codes, even ones that shouldn't be - # possible on a connect operation. - @pytest.mark.it("Raises a custom Exception if Paho connect returns a failing rc code") - @pytest.mark.parametrize( - "error_params", - operation_return_codes, - ids=["{}->{}".format(x["name"], x["error"].__name__) for x in operation_return_codes], - ) - def test_client_returns_failing_rc_code( - self, mocker, mock_mqtt_client, transport, error_params - ): - mock_mqtt_client.connect.return_value = error_params["rc"] - with pytest.raises(error_params["error"]): - transport.connect(fake_password) - - @pytest.fixture( - params=[ - ArbitraryConnectException(), - socket.error(), - ssl.SSLError("socket error", "CERTIFICATE_VERIFY_FAILED"), - socks.SOCKS5Error("it is a sock 5 error", socket_err="a general SOCKS5Error error"), - socks.SOCKS5AuthError( - "it is a sock 5 auth error", socket_err="an auth SOCKS5Error error" - ), - ], - ids=[ - "ArbitraryConnectException", - "socket.error", - "ssl.SSLError", - "socks.SOCKS5Error", - "socks.SOCKS5AuthError", - ], - ) - def connect_exception(self, request): - return request.param - - @pytest.mark.it("Calls _mqtt_client.disconnect if Paho raises an exception") - def test_calls_disconnect_on_exception( - self, mocker, mock_mqtt_client, transport, connect_exception - ): - 
mock_mqtt_client.connect.side_effect = connect_exception - with pytest.raises(Exception): - transport.connect(fake_password) - assert mock_mqtt_client.disconnect.call_count == 1 - - @pytest.mark.it("Calls _mqtt_client.loop_stop if Paho raises an exception") - def test_calls_loop_stop_on_exception( - self, mocker, mock_mqtt_client, transport, connect_exception - ): - mock_mqtt_client.connect.side_effect = connect_exception - with pytest.raises(Exception): - transport.connect(fake_password) - assert mock_mqtt_client.loop_stop.call_count == 1 - - @pytest.mark.it( - "Sets Paho's _thread to None if Paho raises an exception while running in the Paho thread" - ) - def test_sets_thread_to_none_on_exception_in_paho_thread( - self, mocker, mock_mqtt_client, transport, mock_paho_thread_current, connect_exception - ): - mock_mqtt_client.connect.side_effect = connect_exception - with pytest.raises(Exception): - transport.connect(fake_password) - assert mock_mqtt_client._thread is None - - @pytest.mark.it( - "Does not sets Paho's _thread to None if Paho raises an exception running outside the Paho thread" - ) - def test_does_not_set_thread_to_none_on_exception_not_in_paho_thread( - self, mocker, mock_mqtt_client, transport, mock_non_paho_thread_current, connect_exception - ): - mock_mqtt_client.connect.side_effect = connect_exception - with pytest.raises(Exception): - transport.connect(fake_password) - assert mock_mqtt_client._thread is not None - - -@pytest.mark.describe("MQTTTransport - OCCURRENCE: Connect Completed") -class TestEventConnectComplete(object): - @pytest.mark.it( - "Triggers on_mqtt_connected_handler event handler upon successful connect completion" - ) - def test_calls_event_handler_callback(self, mocker, mock_mqtt_client, transport): - callback = mocker.MagicMock() - transport.on_mqtt_connected_handler = callback - - # Manually trigger Paho on_connect event_handler - mock_mqtt_client.on_connect(client=mock_mqtt_client, userdata=None, flags=None, rc=fake_rc) - - 
# Verify transport.on_mqtt_connected_handler was called - assert callback.call_count == 1 - assert callback.call_args == mocker.call() - - @pytest.mark.it( - "Skips on_mqtt_connected_handler event handler if set to 'None' upon successful connect completion" - ) - def test_skips_none_event_handler_callback(self, mocker, mock_mqtt_client, transport): - assert transport.on_mqtt_connected_handler is None - - transport.connect(fake_password) - - mock_mqtt_client.on_connect(client=mock_mqtt_client, userdata=None, flags=None, rc=fake_rc) - - # No further asserts required - this is a test to show that it skips a callback. - # Not raising an exception == test passed - - @pytest.mark.it("Recovers from Exception in on_mqtt_connected_handler event handler") - def test_event_handler_callback_raises_exception( - self, mocker, mock_mqtt_client, transport, arbitrary_exception - ): - event_cb = mocker.MagicMock(side_effect=arbitrary_exception) - transport.on_mqtt_connected_handler = event_cb - - transport.connect(fake_password) - mock_mqtt_client.on_connect(client=mock_mqtt_client, userdata=None, flags=None, rc=fake_rc) - - # Callback was called, but exception did not propagate - assert event_cb.call_count == 1 - - @pytest.mark.it( - "Allows any BaseExceptions raised in on_mqtt_connected_handler event handler to propagate" - ) - def test_event_handler_callback_raises_base_exception( - self, mocker, mock_mqtt_client, transport, arbitrary_base_exception - ): - event_cb = mocker.MagicMock(side_effect=arbitrary_base_exception) - transport.on_mqtt_connected_handler = event_cb - - transport.connect(fake_password) - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - mock_mqtt_client.on_connect( - client=mock_mqtt_client, userdata=None, flags=None, rc=fake_rc - ) - assert e_info.value is arbitrary_base_exception - - -@pytest.mark.describe("MQTTTransport - OCCURRENCE: Connection Failure") -class TestEventConnectionFailure(object): - @pytest.mark.parametrize( - 
"error_params", - connack_return_codes, - ids=["{}->{}".format(x["name"], x["error"].__name__) for x in connack_return_codes], - ) - @pytest.mark.it( - "Triggers on_mqtt_connection_failure_handler event handler with custom Exception upon failed connect completion" - ) - def test_calls_event_handler_callback_with_failed_rc( - self, mocker, mock_mqtt_client, transport, error_params - ): - callback = mocker.MagicMock() - transport.on_mqtt_connection_failure_handler = callback - - # Initiate connect - transport.connect(fake_password) - - # Manually trigger Paho on_connect event_handler - mock_mqtt_client.on_connect( - client=mock_mqtt_client, userdata=None, flags=None, rc=error_params["rc"] - ) - - # Verify transport.on_mqtt_connection_failure_handler was called - assert callback.call_count == 1 - assert isinstance(callback.call_args[0][0], error_params["error"]) - - @pytest.mark.it( - "Skips on_mqtt_connection_failure_handler event handler if set to 'None' upon failed connect completion" - ) - def test_skips_none_event_handler_callback(self, mocker, mock_mqtt_client, transport): - assert transport.on_mqtt_connection_failure_handler is None - - transport.connect(fake_password) - - mock_mqtt_client.on_connect( - client=mock_mqtt_client, userdata=None, flags=None, rc=failed_connack_rc - ) - - # No further asserts required - this is a test to show that it skips a callback. 
- # Not raising an exception == test passed - - @pytest.mark.it("Recovers from Exception in on_mqtt_connection_failure_handler event handler") - def test_event_handler_callback_raises_exception( - self, mocker, mock_mqtt_client, transport, arbitrary_exception - ): - event_cb = mocker.MagicMock(side_effect=arbitrary_exception) - transport.on_mqtt_connection_failure_handler = event_cb - - transport.connect(fake_password) - mock_mqtt_client.on_connect( - client=mock_mqtt_client, userdata=None, flags=None, rc=failed_connack_rc - ) - - # Callback was called, but exception did not propagate - assert event_cb.call_count == 1 - - @pytest.mark.it( - "Allows any BaseExceptions raised in on_mqtt_connection_failure_handler event handler to propagate" - ) - def test_event_handler_callback_raises_base_exception( - self, mocker, mock_mqtt_client, transport, arbitrary_base_exception - ): - event_cb = mocker.MagicMock(side_effect=arbitrary_base_exception) - transport.on_mqtt_connection_failure_handler = event_cb - - transport.connect(fake_password) - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - mock_mqtt_client.on_connect( - client=mock_mqtt_client, userdata=None, flags=None, rc=failed_connack_rc - ) - assert e_info.value is arbitrary_base_exception - - -@pytest.mark.describe("MQTTTransport - .disconnect()") -class TestDisconnect(object): - @pytest.mark.it("Initiates MQTT disconnect via Paho") - def test_calls_paho_disconnect(self, mocker, mock_mqtt_client, transport): - transport.disconnect() - - assert mock_mqtt_client.disconnect.call_count == 1 - assert mock_mqtt_client.disconnect.call_args == mocker.call() - - @pytest.mark.it( - "Raises a ProtocolClientError if Paho disconnect raises an unexpected Exception" - ) - def test_client_raises_unexpected_error( - self, mocker, mock_mqtt_client, transport, arbitrary_exception - ): - mock_mqtt_client.disconnect.side_effect = arbitrary_exception - with pytest.raises(errors.ProtocolClientError) as e_info: - 
transport.disconnect() - assert e_info.value.__cause__ is arbitrary_exception - - @pytest.mark.it("Allows any BaseExceptions raised in Paho disconnect to propagate") - def test_client_raises_base_exception( - self, mock_mqtt_client, transport, arbitrary_base_exception - ): - mock_mqtt_client.disconnect.side_effect = arbitrary_base_exception - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - transport.disconnect() - assert e_info.value is arbitrary_base_exception - - @pytest.mark.it("Raises a custom Exception if Paho disconnect returns a failing rc code") - @pytest.mark.parametrize( - "error_params", - operation_return_codes, - ids=["{}->{}".format(x["name"], x["error"].__name__) for x in operation_return_codes], - ) - def test_client_returns_failing_rc_code( - self, mocker, mock_mqtt_client, transport, error_params - ): - mock_mqtt_client.disconnect.return_value = error_params["rc"] - with pytest.raises(error_params["error"]): - transport.disconnect() - - @pytest.mark.it("Cancels all pending operations if the clear_inflight parameter is True") - def test_pending_op_cancellation(self, mocker, mock_mqtt_client, transport): - # Set up a pending publish - pub_callback = mocker.MagicMock(name="pub cb") - pub_mid = "1" - message_info = mqtt.MQTTMessageInfo(pub_mid) - message_info.rc = fake_rc - mock_mqtt_client.publish.return_value = message_info - transport.publish(topic=fake_topic, payload=fake_payload, callback=pub_callback) - - # Set up a pending subscribe - sub_callback = mocker.MagicMock(name="sub_cb") - sub_mid = "2" - mock_mqtt_client.subscribe.return_value = (fake_rc, sub_mid) - transport.subscribe(topic=fake_topic, qos=fake_qos, callback=sub_callback) - - # Operations are pending - assert pub_callback.call_count == 0 - assert sub_callback.call_count == 0 - - # Disconnect and clear pending ops - transport.disconnect(clear_inflight=True) - - # Pending operations were cancelled - assert pub_callback.call_count == 1 - assert 
pub_callback.call_args == mocker.call(cancelled=True) - assert sub_callback.call_count == 1 - assert sub_callback.call_args == mocker.call(cancelled=True) - - @pytest.mark.it( - "Does not cancel any pending operations if the clear_inflight parameter is False" - ) - def test_no_pending_op_cancellation(self, mocker, mock_mqtt_client, transport): - # Set up a pending publish - pub_callback = mocker.MagicMock(name="pub cb") - pub_mid = "1" - message_info = mqtt.MQTTMessageInfo(pub_mid) - message_info.rc = fake_rc - mock_mqtt_client.publish.return_value = message_info - transport.publish(topic=fake_topic, payload=fake_payload, callback=pub_callback) - - # Set up a pending subscribe - sub_callback = mocker.MagicMock(name="sub_cb") - sub_mid = "2" - mock_mqtt_client.subscribe.return_value = (fake_rc, sub_mid) - transport.subscribe(topic=fake_topic, qos=fake_qos, callback=sub_callback) - - # Operations are pending - assert pub_callback.call_count == 0 - assert sub_callback.call_count == 0 - - # Disconnect - transport.disconnect(clear_inflight=False) - - # No pending operations were cancelled - assert pub_callback.call_count == 0 - assert sub_callback.call_count == 0 - - @pytest.mark.it( - "Does not cancel any pending operations if the clear_inflight parameter is not provided" - ) - def test_default_no_pending_op_cancellation(self, mocker, mock_mqtt_client, transport): - # Set up a pending publish - pub_callback = mocker.MagicMock(name="pub cb") - pub_mid = "1" - message_info = mqtt.MQTTMessageInfo(pub_mid) - message_info.rc = fake_rc - mock_mqtt_client.publish.return_value = message_info - transport.publish(topic=fake_topic, payload=fake_payload, callback=pub_callback) - - # Set up a pending subscribe - sub_callback = mocker.MagicMock(name="sub_cb") - sub_mid = "2" - mock_mqtt_client.subscribe.return_value = (fake_rc, sub_mid) - transport.subscribe(topic=fake_topic, qos=fake_qos, callback=sub_callback) - - # Operations are pending - assert pub_callback.call_count == 0 - 
assert sub_callback.call_count == 0 - - # Disconnect - transport.disconnect() - - # No pending operations were cancelled - assert pub_callback.call_count == 0 - assert sub_callback.call_count == 0 - - @pytest.mark.it("Stops MQTT Network Loop when disconnect does not raise an exception") - def test_calls_loop_stop_on_success(self, mocker, mock_mqtt_client, transport): - transport.disconnect() - - assert mock_mqtt_client.loop_stop.call_count == 1 - assert mock_mqtt_client.loop_stop.call_args == mocker.call() - - @pytest.mark.it("Stops MQTT Network Loop when disconnect raises an exception") - def test_calls_loop_stop_on_exception( - self, mocker, mock_mqtt_client, transport, arbitrary_exception - ): - mock_mqtt_client.disconnect.side_effect = arbitrary_exception - - with pytest.raises(Exception): - transport.disconnect() - - assert mock_mqtt_client.loop_stop.call_count == 1 - assert mock_mqtt_client.loop_stop.call_args == mocker.call() - - @pytest.mark.it( - "Sets Paho's _thread to None if disconnect does not raise an exception while running in the Paho thread" - ) - def test_sets_thread_to_none_on_success_in_paho_thread( - self, mocker, mock_mqtt_client, transport, mock_paho_thread_current - ): - transport.disconnect() - assert mock_mqtt_client._thread is None - - @pytest.mark.it( - "Sets Paho's _thread to None if disconnect raises an exception while running in the Paho thread" - ) - def test_sets_thread_to_none_on_exception_in_paho_thread( - self, mocker, mock_mqtt_client, transport, arbitrary_exception, mock_paho_thread_current - ): - mock_mqtt_client.disconnect.side_effect = arbitrary_exception - - with pytest.raises(Exception): - transport.disconnect() - assert mock_mqtt_client._thread is None - - @pytest.mark.it( - "Does not set Paho's _thread to None if disconnect does not raise an exception while running outside the Paho thread" - ) - def test_does_not_set_thread_to_none_on_success_in_non_paho_thread( - self, mocker, mock_mqtt_client, transport, 
mock_non_paho_thread_current - ): - transport.disconnect() - assert mock_mqtt_client._thread is not None - - @pytest.mark.it( - "Does not set Paho's _thread to None if disconnect raises an exception while running outside the Paho thread" - ) - def test_does_not_set_thread_to_none_on_exception_in_non_paho_thread( - self, mocker, mock_mqtt_client, transport, arbitrary_exception, mock_non_paho_thread_current - ): - mock_mqtt_client.disconnect.side_effect = arbitrary_exception - - with pytest.raises(Exception): - transport.disconnect() - assert mock_mqtt_client._thread is not None - - -@pytest.mark.describe("MQTTTransport - OCCURRENCE: Disconnect Completed") -class TestEventDisconnectCompleted(object): - @pytest.fixture - def collected_transport_weakref(self, mock_mqtt_client): - # return a weak reference to an MQTTTransport that has already been collected - transport = MQTTTransport( - client_id=fake_device_id, hostname=fake_hostname, username=fake_username - ) - transport_weakref = weakref.ref(transport) - transport = None - gc.collect(2) # 2 == collect as much as possible - assert transport_weakref() is None - return transport_weakref - - @pytest.fixture( - params=[fake_success_rc, fake_failed_rc], ids=["success rc code", "failed rc code"] - ) - def rc_success_or_failure(self, request): - return request.param - - @pytest.mark.it( - "Triggers on_mqtt_disconnected_handler event handler upon disconnect completion" - ) - def test_calls_event_handler_callback_externally_driven( - self, mocker, mock_mqtt_client, transport - ): - callback = mocker.MagicMock() - transport.on_mqtt_disconnected_handler = callback - - # Initiate disconnect - transport.disconnect() - - # Manually trigger Paho on_connect event_handler - mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_rc) - - # Verify transport.on_mqtt_connected_handler was called - assert callback.call_count == 1 - assert callback.call_args == mocker.call(None) - - @pytest.mark.parametrize( - 
"error_params", - operation_return_codes, - ids=["{}->{}".format(x["name"], x["error"].__name__) for x in operation_return_codes], - ) - @pytest.mark.it( - "Triggers on_mqtt_disconnected_handler event handler with custom Exception when an error RC is returned upon disconnect completion." - ) - def test_calls_event_handler_callback_with_failure_user_driven( - self, mocker, mock_mqtt_client, transport, error_params - ): - callback = mocker.MagicMock() - transport.on_mqtt_disconnected_handler = callback - - # Initiate disconnect - transport.disconnect() - - # Manually trigger Paho on_disconnect event_handler - mock_mqtt_client.on_disconnect( - client=mock_mqtt_client, userdata=None, rc=error_params["rc"] - ) - - # Verify transport.on_mqtt_disconnected_handler was called - assert callback.call_count == 1 - assert isinstance(callback.call_args[0][0], error_params["error"]) - - @pytest.mark.it( - "Skips on_mqtt_disconnected_handler event handler if set to 'None' upon disconnect completion" - ) - def test_skips_none_event_handler_callback(self, mocker, mock_mqtt_client, transport): - assert transport.on_mqtt_disconnected_handler is None - - transport.disconnect() - - mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_rc) - - # No further asserts required - this is a test to show that it skips a callback. 
- # Not raising an exception == test passed - - @pytest.mark.it("Recovers from Exception in on_mqtt_disconnected_handler event handler") - def test_event_handler_callback_raises_exception( - self, mocker, mock_mqtt_client, transport, arbitrary_exception - ): - event_cb = mocker.MagicMock(side_effect=arbitrary_exception) - transport.on_mqtt_disconnected_handler = event_cb - - transport.disconnect() - mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_rc) - - # Callback was called, but exception did not propagate - assert event_cb.call_count == 1 - - @pytest.mark.it( - "Allows any BaseExceptions raised in on_mqtt_disconnected_handler event handler to propagate" - ) - def test_event_handler_callback_raises_base_exception( - self, mocker, mock_mqtt_client, transport, arbitrary_base_exception - ): - event_cb = mocker.MagicMock(side_effect=arbitrary_base_exception) - transport.on_mqtt_disconnected_handler = event_cb - - transport.disconnect() - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_rc) - assert e_info.value is arbitrary_base_exception - - @pytest.mark.it("Calls Paho's disconnect() method if cause is not None") - def test_calls_disconnect_with_cause(self, mock_mqtt_client, transport): - mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_failed_rc) - assert mock_mqtt_client.disconnect.call_count == 1 - - @pytest.mark.it("Does not call Paho's disconnect() method if cause is None") - def test_doesnt_call_disconnect_without_cause(self, mock_mqtt_client, transport): - mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_success_rc) - assert mock_mqtt_client.disconnect.call_count == 0 - - @pytest.mark.it("Calls Paho's loop_stop() if cause is not None") - def test_calls_loop_stop(self, mock_mqtt_client, transport): - mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, 
rc=fake_failed_rc) - assert mock_mqtt_client.loop_stop.call_count == 1 - - @pytest.mark.it("Does not calls Paho's loop_stop() if cause is None") - def test_does_not_call_loop_stop(self, mock_mqtt_client, transport): - mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_success_rc) - assert mock_mqtt_client.loop_stop.call_count == 0 - - @pytest.mark.it( - "Sets Paho's _thread to None if cause is not None while running in the Paho thread" - ) - def test_sets_thread_to_none_on_failure_in_paho_thread( - self, mock_mqtt_client, transport, mock_paho_thread_current - ): - mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_failed_rc) - assert mock_mqtt_client._thread is None - - @pytest.mark.it( - "Does not set Paho's _thread to None if cause is not None while running outside the paho thread" - ) - def test_sets_thread_to_none_on_failure_in_non_paho_thread( - self, mock_mqtt_client, transport, mock_non_paho_thread_current - ): - mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_failed_rc) - assert mock_mqtt_client._thread is not None - - @pytest.mark.it( - "Does not sets Paho's _thread to None if cause is None while running in the Paho thread" - ) - def test_does_not_set_thread_to_none_on_success_in_paho_thread( - self, mock_mqtt_client, transport, mock_paho_thread_current - ): - mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_success_rc) - assert mock_mqtt_client._thread is not None - - @pytest.mark.it( - "Does not sets Paho's _thread to None if cause is None while running outside the Paho thread" - ) - def test_does_not_set_thread_to_none_on_success_in_non_paho_thread( - self, mock_mqtt_client, transport, mock_non_paho_thread_current - ): - mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_success_rc) - assert mock_mqtt_client._thread is not None - - @pytest.mark.it("Allows any Exception raised by Paho's disconnect() to 
propagate") - def test_disconnect_raises_exception( - self, mock_mqtt_client, transport, mocker, arbitrary_exception - ): - mock_mqtt_client.disconnect = mocker.MagicMock(side_effect=arbitrary_exception) - with pytest.raises(type(arbitrary_exception)) as e_info: - mock_mqtt_client.on_disconnect( - client=mock_mqtt_client, userdata=None, rc=fake_failed_rc - ) - assert e_info.value is arbitrary_exception - - @pytest.mark.it("Allows any BaseException raised by Paho's disconnect() to propagate") - def test_disconnect_raises_base_exception( - self, mock_mqtt_client, transport, mocker, arbitrary_base_exception - ): - mock_mqtt_client.disconnect = mocker.MagicMock(side_effect=arbitrary_base_exception) - with pytest.raises(type(arbitrary_base_exception)) as e_info: - mock_mqtt_client.on_disconnect( - client=mock_mqtt_client, userdata=None, rc=fake_failed_rc - ) - assert e_info.value is arbitrary_base_exception - - @pytest.mark.it("Allows any Exception raised by Paho's loop_stop() to propagate") - def test_loop_stop_raises_exception( - self, mock_mqtt_client, transport, mocker, arbitrary_exception - ): - mock_mqtt_client.loop_stop = mocker.MagicMock(side_effect=arbitrary_exception) - with pytest.raises(type(arbitrary_exception)) as e_info: - mock_mqtt_client.on_disconnect( - client=mock_mqtt_client, userdata=None, rc=fake_failed_rc - ) - assert e_info.value is arbitrary_exception - - @pytest.mark.it("Allows any BaseException raised by Paho's loop_stop() to propagate") - def test_loop_stop_raises_base_exception( - self, mock_mqtt_client, transport, mocker, arbitrary_base_exception - ): - mock_mqtt_client.loop_stop = mocker.MagicMock(side_effect=arbitrary_base_exception) - with pytest.raises(type(arbitrary_base_exception)) as e_info: - mock_mqtt_client.on_disconnect( - client=mock_mqtt_client, userdata=None, rc=fake_failed_rc - ) - assert e_info.value is arbitrary_base_exception - - @pytest.mark.it( - "Does not raise any exceptions if the MQTTTransport object was garbage 
collected before the disconnect completed" - ) - def test_no_exception_after_gc( - self, mock_mqtt_client, collected_transport_weakref, rc_success_or_failure - ): - assert mock_mqtt_client.on_disconnect - mock_mqtt_client.on_disconnect(mock_mqtt_client, None, rc_success_or_failure) - # lack of exception is success - - @pytest.mark.it( - "Calls Paho's loop_stop() if the MQTTTransport object was garbage collected before the disconnect completed" - ) - def test_calls_loop_stop_after_gc( - self, collected_transport_weakref, mock_mqtt_client, rc_success_or_failure, mocker - ): - assert mock_mqtt_client.loop_stop.call_count == 0 - mock_mqtt_client.on_disconnect(mock_mqtt_client, None, rc_success_or_failure) - assert mock_mqtt_client.loop_stop.call_count == 1 - assert mock_mqtt_client.loop_stop.call_args == mocker.call() - - @pytest.mark.it( - "Allows any Exception raised by Paho's loop_stop() to propagate if the MQTTTransport object was garbage collected before the disconnect completed" - ) - def test_raises_exception_after_gc( - self, - collected_transport_weakref, - mock_mqtt_client, - rc_success_or_failure, - arbitrary_exception, - ): - mock_mqtt_client.loop_stop.side_effect = arbitrary_exception - with pytest.raises(type(arbitrary_exception)): - mock_mqtt_client.on_disconnect(mock_mqtt_client, None, rc_success_or_failure) - - @pytest.mark.it( - "Allows any BaseException raised by Paho's loop_stop() to propagate if the MQTTTransport object was garbage collected before the disconnect completed" - ) - def test_raises_base_exception_after_gc( - self, - collected_transport_weakref, - mock_mqtt_client, - rc_success_or_failure, - arbitrary_base_exception, - ): - mock_mqtt_client.loop_stop.side_effect = arbitrary_base_exception - with pytest.raises(type(arbitrary_base_exception)): - mock_mqtt_client.on_disconnect(mock_mqtt_client, None, rc_success_or_failure) - - -@pytest.mark.describe("MQTTTransport - .subscribe()") -class TestSubscribe(object): - 
@pytest.mark.it("Subscribes with Paho") - @pytest.mark.parametrize( - "qos", - [pytest.param(0, id="QoS 0"), pytest.param(1, id="QoS 1"), pytest.param(2, id="QoS 2")], - ) - def test_calls_paho_subscribe(self, mocker, mock_mqtt_client, transport, qos): - transport.subscribe(fake_topic, qos=qos) - - assert mock_mqtt_client.subscribe.call_count == 1 - assert mock_mqtt_client.subscribe.call_args == mocker.call(fake_topic, qos=qos) - - @pytest.mark.it("Raises ValueError on invalid QoS") - @pytest.mark.parametrize("qos", [pytest.param(-1, id="QoS < 0"), pytest.param(3, id="QoS > 2")]) - def test_raises_value_error_invalid_qos(self, qos): - # Manually instantiate protocol wrapper, do NOT mock paho client (paho generates this error) - transport = MQTTTransport( - client_id=fake_device_id, hostname=fake_hostname, username=fake_username - ) - with pytest.raises(ValueError): - transport.subscribe(fake_topic, qos=qos) - - @pytest.mark.it("Raises ValueError on invalid topic string") - @pytest.mark.parametrize("topic", [pytest.param(None), pytest.param("", id="Empty string")]) - def test_raises_value_error_invalid_topic(self, topic): - # Manually instantiate protocol wrapper, do NOT mock paho client (paho generates this error) - transport = MQTTTransport( - client_id=fake_device_id, hostname=fake_hostname, username=fake_username - ) - with pytest.raises(ValueError): - transport.subscribe(topic, qos=fake_qos) - - @pytest.mark.it("Triggers callback upon subscribe completion") - def test_triggers_callback_upon_paho_on_subscribe_event( - self, mocker, mock_mqtt_client, transport - ): - callback = mocker.MagicMock() - mock_mqtt_client.subscribe.return_value = (fake_rc, fake_mid) - - # Initiate subscribe - transport.subscribe(topic=fake_topic, qos=fake_qos, callback=callback) - - # Check callback is not called yet - assert callback.call_count == 0 - - # Manually trigger Paho on_subscribe event handler - mock_mqtt_client.on_subscribe( - client=mock_mqtt_client, userdata=None, 
mid=fake_mid, granted_qos=fake_qos - ) - - # Check callback has now been called - assert callback.call_count == 1 - - @pytest.mark.it( - "Triggers callback upon subscribe completion when Paho event handler triggered early" - ) - def test_triggers_callback_when_paho_on_subscribe_event_called_early( - self, mocker, mock_mqtt_client, transport - ): - callback = mocker.MagicMock() - - def trigger_early_on_subscribe(topic, qos): - - # Trigger on_subscribe before returning mid - mock_mqtt_client.on_subscribe( - client=mock_mqtt_client, userdata=None, mid=fake_mid, granted_qos=fake_qos - ) - - # Check callback not yet called - assert callback.call_count == 0 - - return (fake_rc, fake_mid) - - mock_mqtt_client.subscribe.side_effect = trigger_early_on_subscribe - - # Initiate subscribe - transport.subscribe(topic=fake_topic, qos=fake_qos, callback=callback) - - # Check callback has now been called - assert callback.call_count == 1 - - @pytest.mark.it("Skips callback that is set to 'None' upon subscribe completion") - def test_none_callback_upon_paho_on_subscribe_event(self, mocker, mock_mqtt_client, transport): - callback = None - mock_mqtt_client.subscribe.return_value = (fake_rc, fake_mid) - - # Initiate subscribe - transport.subscribe(topic=fake_topic, qos=fake_qos, callback=callback) - - # Manually trigger Paho on_subscribe event handler - mock_mqtt_client.on_subscribe( - client=mock_mqtt_client, userdata=None, mid=fake_mid, granted_qos=fake_qos - ) - - # No assertions necessary - not raising an exception => success - - @pytest.mark.it( - "Skips callback that is set to 'None' upon subscribe completion when Paho event handler triggered early" - ) - def test_none_callback_when_paho_on_subscribe_event_called_early( - self, mocker, mock_mqtt_client, transport - ): - callback = None - - def trigger_early_on_subscribe(topic, qos): - - # Trigger on_subscribe before returning mid - mock_mqtt_client.on_subscribe( - client=mock_mqtt_client, userdata=None, mid=fake_mid, 
granted_qos=fake_qos - ) - - return (fake_rc, fake_mid) - - mock_mqtt_client.subscribe.side_effect = trigger_early_on_subscribe - - # Initiate subscribe - transport.subscribe(topic=fake_topic, qos=fake_qos, callback=callback) - - # No assertions necessary - not raising an exception => success - - @pytest.mark.it( - "Handles multiple callbacks from multiple subscribe operations that complete out of order" - ) - def test_multiple_callbacks(self, mocker, mock_mqtt_client, transport): - callback1 = mocker.MagicMock() - callback2 = mocker.MagicMock() - callback3 = mocker.MagicMock() - - mid1 = 1 - mid2 = 2 - mid3 = 3 - - mock_mqtt_client.subscribe.side_effect = [(fake_rc, mid1), (fake_rc, mid2), (fake_rc, mid3)] - - # Initiate subscribe (1 -> 2 -> 3) - transport.subscribe(topic=fake_topic, qos=fake_qos, callback=callback1) - transport.subscribe(topic=fake_topic, qos=fake_qos, callback=callback2) - transport.subscribe(topic=fake_topic, qos=fake_qos, callback=callback3) - - # Check callbacks have not yet been called - assert callback1.call_count == 0 - assert callback2.call_count == 0 - assert callback3.call_count == 0 - - # Manually trigger Paho on_subscribe event handler (2 -> 3 -> 1) - mock_mqtt_client.on_subscribe( - client=mock_mqtt_client, userdata=None, mid=mid2, granted_qos=fake_qos - ) - assert callback1.call_count == 0 - assert callback2.call_count == 1 - assert callback3.call_count == 0 - - mock_mqtt_client.on_subscribe( - client=mock_mqtt_client, userdata=None, mid=mid3, granted_qos=fake_qos - ) - assert callback1.call_count == 0 - assert callback2.call_count == 1 - assert callback3.call_count == 1 - - mock_mqtt_client.on_subscribe( - client=mock_mqtt_client, userdata=None, mid=mid1, granted_qos=fake_qos - ) - assert callback1.call_count == 1 - assert callback2.call_count == 1 - assert callback3.call_count == 1 - - @pytest.mark.it("Recovers from Exception in callback") - def test_callback_raises_exception( - self, mocker, mock_mqtt_client, transport, 
arbitrary_exception - ): - callback = mocker.MagicMock(side_effect=arbitrary_exception) - mock_mqtt_client.subscribe.return_value = (fake_rc, fake_mid) - - transport.subscribe(topic=fake_topic, qos=fake_qos, callback=callback) - mock_mqtt_client.on_subscribe( - client=mock_mqtt_client, userdata=None, mid=fake_mid, granted_qos=fake_qos - ) - - # Callback was called, but exception did not propagate - assert callback.call_count == 1 - - @pytest.mark.it("Allows any BaseExceptions raised in callback to propagate") - def test_callback_raises_base_exception( - self, mocker, mock_mqtt_client, transport, arbitrary_base_exception - ): - callback = mocker.MagicMock(side_effect=arbitrary_base_exception) - mock_mqtt_client.subscribe.return_value = (fake_rc, fake_mid) - - transport.subscribe(topic=fake_topic, qos=fake_qos, callback=callback) - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - mock_mqtt_client.on_subscribe( - client=mock_mqtt_client, userdata=None, mid=fake_mid, granted_qos=fake_qos - ) - assert e_info.value is arbitrary_base_exception - - @pytest.mark.it("Recovers from Exception in callback when Paho event handler triggered early") - def test_callback_raises_exception_when_paho_on_subscribe_triggered_early( - self, mocker, mock_mqtt_client, transport, arbitrary_exception - ): - callback = mocker.MagicMock(side_effect=arbitrary_exception) - - def trigger_early_on_subscribe(topic, qos): - mock_mqtt_client.on_subscribe( - client=mock_mqtt_client, userdata=None, mid=fake_mid, granted_qos=fake_qos - ) - - # Should not have yet called callback - assert callback.call_count == 0 - - return (fake_rc, fake_mid) - - mock_mqtt_client.subscribe.side_effect = trigger_early_on_subscribe - - # Initiate subscribe - transport.subscribe(topic=fake_topic, qos=fake_qos, callback=callback) - - # Callback was called, but exception did not propagate - assert callback.call_count == 1 - - @pytest.mark.it( - "Allows any BaseExceptions raised in callback when Paho event 
handler triggered early to propagate" - ) - def test_callback_raises_base_exception_when_paho_on_subscribe_triggered_early( - self, mocker, mock_mqtt_client, transport, arbitrary_base_exception - ): - callback = mocker.MagicMock(side_effect=arbitrary_base_exception) - - def trigger_early_on_subscribe(topic, qos): - mock_mqtt_client.on_subscribe( - client=mock_mqtt_client, userdata=None, mid=fake_mid, granted_qos=fake_qos - ) - - # Should not have yet called callback - assert callback.call_count == 0 - - return (fake_rc, fake_mid) - - mock_mqtt_client.subscribe.side_effect = trigger_early_on_subscribe - - # Initiate subscribe - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - transport.subscribe(topic=fake_topic, qos=fake_qos, callback=callback) - assert e_info.value is arbitrary_base_exception - - @pytest.mark.it("Raises a ProtocolClientError if Paho subscribe raises an unexpected Exception") - def test_client_raises_unexpected_error( - self, mocker, mock_mqtt_client, transport, arbitrary_exception - ): - mock_mqtt_client.subscribe.side_effect = arbitrary_exception - with pytest.raises(errors.ProtocolClientError) as e_info: - transport.subscribe(topic=fake_topic, qos=fake_qos, callback=None) - assert e_info.value.__cause__ is arbitrary_exception - - @pytest.mark.it("Allows any BaseExceptions raised in Paho subscribe to propagate") - def test_client_raises_base_exception( - self, mock_mqtt_client, transport, arbitrary_base_exception - ): - mock_mqtt_client.subscribe.side_effect = arbitrary_base_exception - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - transport.subscribe(topic=fake_topic, qos=fake_qos, callback=None) - assert e_info.value is arbitrary_base_exception - - # NOTE: this test tests for all possible return codes, even ones that shouldn't be - # possible on a subscribe operation. 
- @pytest.mark.it("Raises a custom Exception if Paho subscribe returns a failing rc code") - @pytest.mark.parametrize( - "error_params", - operation_return_codes, - ids=["{}->{}".format(x["name"], x["error"].__name__) for x in operation_return_codes], - ) - def test_client_returns_failing_rc_code( - self, mocker, mock_mqtt_client, transport, error_params - ): - mock_mqtt_client.subscribe.return_value = (error_params["rc"], 0) - with pytest.raises(error_params["error"]): - transport.subscribe(topic=fake_topic, qos=fake_qos, callback=None) - - -@pytest.mark.describe("MQTTTransport - .unsubscribe()") -class TestUnsubscribe(object): - @pytest.mark.it("Unsubscribes with Paho") - def test_calls_paho_unsubscribe(self, mocker, mock_mqtt_client, transport): - transport.unsubscribe(fake_topic) - - assert mock_mqtt_client.unsubscribe.call_count == 1 - assert mock_mqtt_client.unsubscribe.call_args == mocker.call(fake_topic) - - @pytest.mark.it("Raises ValueError on invalid topic string") - @pytest.mark.parametrize("topic", [pytest.param(None), pytest.param("", id="Empty string")]) - def test_raises_value_error_invalid_topic(self, topic): - # Manually instantiate protocol wrapper, do NOT mock paho client (paho generates this error) - transport = MQTTTransport( - client_id=fake_device_id, hostname=fake_hostname, username=fake_username - ) - with pytest.raises(ValueError): - transport.unsubscribe(topic) - - @pytest.mark.it("Triggers callback upon unsubscribe completion") - def test_triggers_callback_upon_paho_on_unsubscribe_event( - self, mocker, mock_mqtt_client, transport - ): - callback = mocker.MagicMock() - mock_mqtt_client.unsubscribe.return_value = (fake_rc, fake_mid) - - # Initiate unsubscribe - transport.unsubscribe(topic=fake_topic, callback=callback) - - # Check callback not called - assert callback.call_count == 0 - - # Manually trigger Paho on_unsubscribe event handler - mock_mqtt_client.on_unsubscribe(client=mock_mqtt_client, userdata=None, mid=fake_mid) - - # Check 
callback has now been called - assert callback.call_count == 1 - - @pytest.mark.it( - "Triggers callback upon unsubscribe completion when Paho event handler triggered early" - ) - def test_triggers_callback_when_paho_on_unsubscribe_event_called_early( - self, mocker, mock_mqtt_client, transport - ): - callback = mocker.MagicMock() - - def trigger_early_on_unsubscribe(topic): - - # Trigger on_unsubscribe before returning mid - mock_mqtt_client.on_unsubscribe(client=mock_mqtt_client, userdata=None, mid=fake_mid) - - # Check callback not yet called - assert callback.call_count == 0 - - return (fake_rc, fake_mid) - - mock_mqtt_client.unsubscribe.side_effect = trigger_early_on_unsubscribe - - # Initiate unsubscribe - transport.unsubscribe(topic=fake_topic, callback=callback) - - # Check callback has now been called - assert callback.call_count == 1 - - @pytest.mark.it("Skips callback that is set to 'None' upon unsubscribe completion") - def test_none_callback_upon_paho_on_unsubscribe_event( - self, mocker, mock_mqtt_client, transport - ): - callback = None - mock_mqtt_client.unsubscribe.return_value = (fake_rc, fake_mid) - - # Initiate unsubscribe - transport.unsubscribe(topic=fake_topic, callback=callback) - - # Manually trigger Paho on_unsubscribe event handler - mock_mqtt_client.on_unsubscribe(client=mock_mqtt_client, userdata=None, mid=fake_mid) - - # No assertions necessary - not raising an exception => success - - @pytest.mark.it( - "Skips callback that is set to 'None' upon unsubscribe completion when Paho event handler triggered early" - ) - def test_none_callback_when_paho_on_unsubscribe_event_called_early( - self, mocker, mock_mqtt_client, transport - ): - callback = None - - def trigger_early_on_unsubscribe(topic): - - # Trigger on_unsubscribe before returning mid - mock_mqtt_client.on_unsubscribe(client=mock_mqtt_client, userdata=None, mid=fake_mid) - - return (fake_rc, fake_mid) - - mock_mqtt_client.unsubscribe.side_effect = trigger_early_on_unsubscribe - - 
# Initiate unsubscribe - transport.unsubscribe(topic=fake_topic, callback=callback) - - # No assertions necessary - not raising an exception => success - - @pytest.mark.it( - "Handles multiple callbacks from multiple unsubscribe operations that complete out of order" - ) - def test_multiple_callbacks(self, mocker, mock_mqtt_client, transport): - callback1 = mocker.MagicMock() - callback2 = mocker.MagicMock() - callback3 = mocker.MagicMock() - - mid1 = 1 - mid2 = 2 - mid3 = 3 - - mock_mqtt_client.unsubscribe.side_effect = [ - (fake_rc, mid1), - (fake_rc, mid2), - (fake_rc, mid3), - ] - - # Initiate unsubscribe (1 -> 2 -> 3) - transport.unsubscribe(topic=fake_topic, callback=callback1) - transport.unsubscribe(topic=fake_topic, callback=callback2) - transport.unsubscribe(topic=fake_topic, callback=callback3) - - # Check callbacks have not yet been called - assert callback1.call_count == 0 - assert callback2.call_count == 0 - assert callback3.call_count == 0 - - # Manually trigger Paho on_unsubscribe event handler (2 -> 3 -> 1) - mock_mqtt_client.on_unsubscribe(client=mock_mqtt_client, userdata=None, mid=mid2) - assert callback1.call_count == 0 - assert callback2.call_count == 1 - assert callback3.call_count == 0 - - mock_mqtt_client.on_unsubscribe(client=mock_mqtt_client, userdata=None, mid=mid3) - assert callback1.call_count == 0 - assert callback2.call_count == 1 - assert callback3.call_count == 1 - - mock_mqtt_client.on_unsubscribe(client=mock_mqtt_client, userdata=None, mid=mid1) - assert callback1.call_count == 1 - assert callback2.call_count == 1 - assert callback3.call_count == 1 - - @pytest.mark.it("Recovers from Exception in callback") - def test_callback_raises_exception( - self, mocker, mock_mqtt_client, transport, arbitrary_exception - ): - callback = mocker.MagicMock(side_effect=arbitrary_exception) - mock_mqtt_client.unsubscribe.return_value = (fake_rc, fake_mid) - - transport.unsubscribe(topic=fake_topic, callback=callback) - 
mock_mqtt_client.on_unsubscribe(client=mock_mqtt_client, userdata=None, mid=fake_mid) - - # Callback was called, but exception did not propagate - assert callback.call_count == 1 - - @pytest.mark.it("Allows any BaseExceptions raised in callback to propagate") - def test_callback_raises_base_exception( - self, mocker, mock_mqtt_client, transport, arbitrary_base_exception - ): - callback = mocker.MagicMock(side_effect=arbitrary_base_exception) - mock_mqtt_client.unsubscribe.return_value = (fake_rc, fake_mid) - - transport.unsubscribe(topic=fake_topic, callback=callback) - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - mock_mqtt_client.on_unsubscribe(client=mock_mqtt_client, userdata=None, mid=fake_mid) - assert e_info.value is arbitrary_base_exception - - @pytest.mark.it("Recovers from Exception in callback when Paho event handler triggered early") - def test_callback_raises_exception_when_paho_on_unsubscribe_triggered_early( - self, mocker, mock_mqtt_client, transport, arbitrary_exception - ): - callback = mocker.MagicMock(side_effect=arbitrary_exception) - - def trigger_early_on_unsubscribe(topic): - mock_mqtt_client.on_unsubscribe(client=mock_mqtt_client, userdata=None, mid=fake_mid) - - # Should not have yet called callback - assert callback.call_count == 0 - - return (fake_rc, fake_mid) - - mock_mqtt_client.unsubscribe.side_effect = trigger_early_on_unsubscribe - - # Initiate unsubscribe - transport.unsubscribe(topic=fake_topic, callback=callback) - - # Callback was called, but exception did not propagate - assert callback.call_count == 1 - - @pytest.mark.it( - "Allows any BaseExceptions raised in callback when Paho event handler triggered early to propagate" - ) - def test_callback_raises_base_exception_when_paho_on_unsubscribe_triggered_early( - self, mocker, mock_mqtt_client, transport, arbitrary_base_exception - ): - callback = mocker.MagicMock(side_effect=arbitrary_base_exception) - - def trigger_early_on_unsubscribe(topic): - 
mock_mqtt_client.on_unsubscribe(client=mock_mqtt_client, userdata=None, mid=fake_mid) - - # Should not have yet called callback - assert callback.call_count == 0 - - return (fake_rc, fake_mid) - - mock_mqtt_client.unsubscribe.side_effect = trigger_early_on_unsubscribe - - # Initiate unsubscribe - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - transport.unsubscribe(topic=fake_topic, callback=callback) - assert e_info.value is arbitrary_base_exception - - @pytest.mark.it( - "Raises a ProtocolClientError if Paho unsubscribe raises an unexpected Exception" - ) - def test_client_raises_unexpected_error( - self, mocker, mock_mqtt_client, transport, arbitrary_exception - ): - mock_mqtt_client.unsubscribe.side_effect = arbitrary_exception - with pytest.raises(errors.ProtocolClientError) as e_info: - transport.unsubscribe(topic=fake_topic, callback=None) - assert e_info.value.__cause__ is arbitrary_exception - - @pytest.mark.it("Allows any BaseExceptions raised in Paho unsubscribe to propagate") - def test_client_raises_base_exception( - self, mock_mqtt_client, transport, arbitrary_base_exception - ): - mock_mqtt_client.unsubscribe.side_effect = arbitrary_base_exception - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - transport.unsubscribe(topic=fake_topic, callback=None) - assert e_info.value is arbitrary_base_exception - - # NOTE: this test tests for all possible return codes, even ones that shouldn't be - # possible on an unsubscribe operation. 
- @pytest.mark.it("Raises a custom Exception if Paho unsubscribe returns a failing rc code") - @pytest.mark.parametrize( - "error_params", - operation_return_codes, - ids=["{}->{}".format(x["name"], x["error"].__name__) for x in operation_return_codes], - ) - def test_client_returns_failing_rc_code( - self, mocker, mock_mqtt_client, transport, error_params - ): - mock_mqtt_client.unsubscribe.return_value = (error_params["rc"], 0) - with pytest.raises(error_params["error"]): - transport.unsubscribe(topic=fake_topic, callback=None) - - -@pytest.mark.describe("MQTTTransport - .publish()") -class TestPublish(object): - @pytest.fixture - def message_info(self, mocker): - mi = mqtt.MQTTMessageInfo(fake_mid) - mi.rc = fake_rc - return mi - - @pytest.mark.it("Publishes with Paho") - @pytest.mark.parametrize( - "qos", - [pytest.param(0, id="QoS 0"), pytest.param(1, id="QoS 1"), pytest.param(2, id="QoS 2")], - ) - def test_calls_paho_publish(self, mocker, mock_mqtt_client, transport, qos): - transport.publish(topic=fake_topic, payload=fake_payload, qos=qos) - - assert mock_mqtt_client.publish.call_count == 1 - assert mock_mqtt_client.publish.call_args == mocker.call( - topic=fake_topic, payload=fake_payload, qos=qos - ) - - @pytest.mark.it("Raises ValueError on invalid QoS") - @pytest.mark.parametrize("qos", [pytest.param(-1, id="QoS < 0"), pytest.param(3, id="Qos > 2")]) - def test_raises_value_error_invalid_qos(self, qos): - # Manually instantiate protocol wrapper, do NOT mock paho client (paho generates this error) - transport = MQTTTransport( - client_id=fake_device_id, hostname=fake_hostname, username=fake_username - ) - with pytest.raises(ValueError): - transport.publish(topic=fake_topic, payload=fake_payload, qos=qos) - - @pytest.mark.it("Raises ValueError on invalid topic string") - @pytest.mark.parametrize( - "topic", - [ - pytest.param(None), - pytest.param("", id="Empty string"), - pytest.param("+", id="Contains wildcard (+)"), - ], - ) - def 
test_raises_value_error_invalid_topic(self, topic): - # Manually instantiate protocol wrapper, do NOT mock paho client (paho generates this error) - transport = MQTTTransport( - client_id=fake_device_id, hostname=fake_hostname, username=fake_username - ) - with pytest.raises(ValueError): - transport.publish(topic=topic, payload=fake_payload, qos=fake_qos) - - @pytest.mark.it("Raises ValueError on invalid payload value") - @pytest.mark.parametrize("payload", [str(b"0" * 268435456)], ids=["Payload > 268435455 bytes"]) - def test_raises_value_error_invalid_payload(self, payload): - # Manually instantiate protocol wrapper, do NOT mock paho client (paho generates this error) - transport = MQTTTransport( - client_id=fake_device_id, hostname=fake_hostname, username=fake_username - ) - with pytest.raises(ValueError): - transport.publish(topic=fake_topic, payload=payload, qos=fake_qos) - - @pytest.mark.it("Raises TypeError on invalid payload type") - @pytest.mark.parametrize( - "payload", - [ - pytest.param({"a": "b"}, id="Dictionary"), - pytest.param([1, 2, 3], id="List"), - pytest.param(object(), id="Object"), - ], - ) - def test_raises_type_error_invalid_payload_type(self, payload): - # Manually instantiate protocol wrapper, do NOT mock paho client (paho generates this error) - transport = MQTTTransport( - client_id=fake_device_id, hostname=fake_hostname, username=fake_username - ) - with pytest.raises(TypeError): - transport.publish(topic=fake_topic, payload=payload, qos=fake_qos) - - @pytest.mark.it("Triggers callback upon publish completion") - def test_triggers_callback_upon_paho_on_publish_event( - self, mocker, mock_mqtt_client, transport, message_info - ): - callback = mocker.MagicMock() - mock_mqtt_client.publish.return_value = message_info - - # Initiate publish - transport.publish(topic=fake_topic, payload=fake_payload, callback=callback) - - # Check callback is not called - assert callback.call_count == 0 - - # Manually trigger Paho on_publish event handler - 
mock_mqtt_client.on_publish(client=mock_mqtt_client, userdata=None, mid=message_info.mid) - - # Check callback has now been called - assert callback.call_count == 1 - - @pytest.mark.it( - "Triggers callback upon publish completion when Paho event handler triggered early" - ) - def test_triggers_callback_when_paho_on_publish_event_called_early( - self, mocker, mock_mqtt_client, transport, message_info - ): - callback = mocker.MagicMock() - - def trigger_early_on_publish(topic, payload, qos): - - # Trigger on_publish before returning message_info - mock_mqtt_client.on_publish( - client=mock_mqtt_client, userdata=None, mid=message_info.mid - ) - - # Check callback not yet called - assert callback.call_count == 0 - - return message_info - - mock_mqtt_client.publish.side_effect = trigger_early_on_publish - - # Initiate publish - transport.publish(topic=fake_topic, payload=fake_payload, callback=callback) - - # Check callback has now been called - assert callback.call_count == 1 - - @pytest.mark.it("Skips callback that is set to 'None' upon publish completion") - def test_none_callback_upon_paho_on_publish_event( - self, mocker, mock_mqtt_client, transport, message_info - ): - mock_mqtt_client.publish.return_value = message_info - callback = None - - # Initiate publish - transport.publish(topic=fake_topic, payload=fake_payload, callback=callback) - - # Manually trigger Paho on_publish event handler - mock_mqtt_client.on_publish(client=mock_mqtt_client, userdata=None, mid=message_info.mid) - - # No assertions necessary - not raising an exception => success - - @pytest.mark.it( - "Skips callback that is set to 'None' upon publish completion when Paho event handler triggered early" - ) - def test_none_callback_when_paho_on_publish_event_called_early( - self, mocker, mock_mqtt_client, transport, message_info - ): - callback = None - - def trigger_early_on_publish(topic, payload, qos): - - # Trigger on_publish before returning message_info - mock_mqtt_client.on_publish( - 
client=mock_mqtt_client, userdata=None, mid=message_info.mid - ) - - return message_info - - mock_mqtt_client.publish.side_effect = trigger_early_on_publish - - # Initiate publish - transport.publish(topic=fake_topic, payload=fake_payload, callback=callback) - - # No assertions necessary - not raising an exception => success - - @pytest.mark.it( - "Handles multiple callbacks from multiple publish operations that complete out of order" - ) - def test_multiple_callbacks(self, mocker, mock_mqtt_client, transport): - callback1 = mocker.MagicMock() - callback2 = mocker.MagicMock() - callback3 = mocker.MagicMock() - - mid1 = 1 - mid2 = 2 - mid3 = 3 - - mock_mqtt_client.publish.side_effect = [ - mqtt.MQTTMessageInfo(mid1), - mqtt.MQTTMessageInfo(mid2), - mqtt.MQTTMessageInfo(mid3), - ] - - # Initiate publish (1 -> 2 -> 3) - transport.publish(topic=fake_topic, payload=fake_payload, callback=callback1) - transport.publish(topic=fake_topic, payload=fake_payload, callback=callback2) - transport.publish(topic=fake_topic, payload=fake_payload, callback=callback3) - - # Check callbacks have not yet been called - assert callback1.call_count == 0 - assert callback2.call_count == 0 - assert callback3.call_count == 0 - - # Manually trigger Paho on_publish event handler (2 -> 3 -> 1) - mock_mqtt_client.on_publish(client=mock_mqtt_client, userdata=None, mid=mid2) - assert callback1.call_count == 0 - assert callback2.call_count == 1 - assert callback3.call_count == 0 - - mock_mqtt_client.on_publish(client=mock_mqtt_client, userdata=None, mid=mid3) - assert callback1.call_count == 0 - assert callback2.call_count == 1 - assert callback3.call_count == 1 - - mock_mqtt_client.on_publish(client=mock_mqtt_client, userdata=None, mid=mid1) - assert callback1.call_count == 1 - assert callback2.call_count == 1 - assert callback3.call_count == 1 - - @pytest.mark.it("Recovers from Exception in callback") - def test_callback_raises_exception( - self, mocker, mock_mqtt_client, transport, 
message_info, arbitrary_exception - ): - callback = mocker.MagicMock(side_effect=arbitrary_exception) - mock_mqtt_client.publish.return_value = message_info - - transport.publish(topic=fake_topic, payload=fake_payload, callback=callback) - mock_mqtt_client.on_publish(client=mock_mqtt_client, userdata=None, mid=message_info.mid) - - # Callback was called, but exception did not propagate - assert callback.call_count == 1 - - @pytest.mark.it("Allows any BaseExceptions raised in callback to propagate") - def test_callback_raises_base_exception( - self, mocker, mock_mqtt_client, transport, message_info, arbitrary_base_exception - ): - callback = mocker.MagicMock(side_effect=arbitrary_base_exception) - mock_mqtt_client.publish.return_value = message_info - - transport.publish(topic=fake_topic, payload=fake_payload, callback=callback) - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - mock_mqtt_client.on_publish( - client=mock_mqtt_client, userdata=None, mid=message_info.mid - ) - assert e_info.value is arbitrary_base_exception - - @pytest.mark.it("Recovers from Exception in callback when Paho event handler triggered early") - def test_callback_raises_exception_when_paho_on_publish_triggered_early( - self, mocker, mock_mqtt_client, transport, message_info, arbitrary_exception - ): - callback = mocker.MagicMock(side_effect=arbitrary_exception) - - def trigger_early_on_publish(topic, payload, qos): - mock_mqtt_client.on_publish( - client=mock_mqtt_client, userdata=None, mid=message_info.mid - ) - - # Should not have yet called callback - assert callback.call_count == 0 - - return message_info - - mock_mqtt_client.publish.side_effect = trigger_early_on_publish - - # Initiate publish - transport.publish(topic=fake_topic, payload=fake_payload, callback=callback) - - # Callback was called, but exception did not propagate - assert callback.call_count == 1 - - @pytest.mark.it( - "Allows any BaseExceptions raised in callback when Paho event handler triggered 
early to propagate" - ) - def test_callback_raises_base_exception_when_paho_on_publish_triggered_early( - self, mocker, mock_mqtt_client, transport, message_info, arbitrary_base_exception - ): - callback = mocker.MagicMock(side_effect=arbitrary_base_exception) - - def trigger_early_on_publish(topic, payload, qos): - mock_mqtt_client.on_publish( - client=mock_mqtt_client, userdata=None, mid=message_info.mid - ) - - # Should not have yet called callback - assert callback.call_count == 0 - - return message_info - - mock_mqtt_client.publish.side_effect = trigger_early_on_publish - - # Initiate publish - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - transport.publish(topic=fake_topic, payload=fake_payload, callback=callback) - assert e_info.value is arbitrary_base_exception - - @pytest.mark.it("Raises a ProtocolClientError if Paho publish raises an unexpected Exception") - def test_client_raises_unexpected_error( - self, mocker, mock_mqtt_client, transport, arbitrary_exception - ): - mock_mqtt_client.publish.side_effect = arbitrary_exception - with pytest.raises(errors.ProtocolClientError) as e_info: - transport.publish(topic=fake_topic, payload=fake_payload, callback=None) - assert e_info.value.__cause__ is arbitrary_exception - - @pytest.mark.it("Allows any BaseExceptions raised in Paho publish to propagate") - def test_client_raises_base_exception( - self, mock_mqtt_client, transport, arbitrary_base_exception - ): - mock_mqtt_client.publish.side_effect = arbitrary_base_exception - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - transport.publish(topic=fake_topic, payload=fake_payload, callback=None) - assert e_info.value is arbitrary_base_exception - - # NOTE: this test tests for all possible return codes, even ones that shouldn't be - # possible on a publish operation. 
- @pytest.mark.it("Raises a custom Exception if Paho publish returns a failing rc code") - @pytest.mark.parametrize( - "error_params", - operation_return_codes, - ids=["{}->{}".format(x["name"], x["error"].__name__) for x in operation_return_codes], - ) - def test_client_returns_failing_rc_code( - self, mocker, mock_mqtt_client, transport, error_params - ): - mock_mqtt_client.publish.return_value = (error_params["rc"], 0) - with pytest.raises(error_params["error"]): - transport.publish(topic=fake_topic, payload=fake_payload, callback=None) - - -@pytest.mark.describe("MQTTTransport - OCCURRENCE: Message Received") -class TestMessageReceived(object): - @pytest.fixture() - def message(self): - message = mqtt.MQTTMessage(mid=fake_mid, topic=fake_topic.encode()) - message.payload = fake_payload - message.qos = fake_qos - return message - - @pytest.mark.it( - "Triggers on_mqtt_message_received_handler event handler upon receiving message" - ) - def test_calls_event_handler_callback(self, mocker, mock_mqtt_client, transport, message): - callback = mocker.MagicMock() - transport.on_mqtt_message_received_handler = callback - - # Manually trigger Paho on_message event_handler - mock_mqtt_client.on_message(client=mock_mqtt_client, userdata=None, mqtt_message=message) - - # Verify transport.on_mqtt_message_received_handler was called - assert callback.call_count == 1 - assert callback.call_args == mocker.call(message.topic, message.payload) - - @pytest.mark.it( - "Skips on_mqtt_message_received_handler event handler if set to 'None' upon receiving message" - ) - def test_skips_none_event_handler_callback(self, mocker, mock_mqtt_client, transport, message): - assert transport.on_mqtt_message_received_handler is None - - # Manually trigger Paho on_message event_handler - mock_mqtt_client.on_message(client=mock_mqtt_client, userdata=None, mqtt_message=message) - - # No further asserts required - this is a test to show that it skips a callback. 
- # Not raising an exception == test passed - - @pytest.mark.it("Recovers from Exception in on_mqtt_message_received_handler event handler") - def test_event_handler_callback_raises_exception( - self, mocker, mock_mqtt_client, transport, message, arbitrary_exception - ): - event_cb = mocker.MagicMock(side_effect=arbitrary_exception) - transport.on_mqtt_message_received_handler = event_cb - - mock_mqtt_client.on_message(client=mock_mqtt_client, userdata=None, mqtt_message=message) - - # Callback was called, but exception did not propagate - assert event_cb.call_count == 1 - - @pytest.mark.it( - "Allows any BaseExceptions raised in on_mqtt_message_received_handler event handler to propagate" - ) - def test_event_handler_callback_raises_base_exception( - self, mocker, mock_mqtt_client, transport, message, arbitrary_base_exception - ): - event_cb = mocker.MagicMock(side_effect=arbitrary_base_exception) - transport.on_mqtt_message_received_handler = event_cb - - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - mock_mqtt_client.on_message( - client=mock_mqtt_client, userdata=None, mqtt_message=message - ) - assert e_info.value is arbitrary_base_exception - - -@pytest.mark.describe("MQTTTransport - Misc.") -class TestMisc(object): - @pytest.mark.it( - "Handles multiple callbacks from multiple different types of operations that complete out of order" - ) - def test_multiple_callbacks_multiple_ops(self, mocker, mock_mqtt_client, transport): - callback1 = mocker.MagicMock() - callback2 = mocker.MagicMock() - callback3 = mocker.MagicMock() - - mid1 = 1 - mid2 = 2 - mid3 = 3 - - topic1 = "topic1" - topic2 = "topic2" - topic3 = "topic3" - - mock_mqtt_client.subscribe.return_value = (fake_rc, mid1) - mock_mqtt_client.publish.return_value = mqtt.MQTTMessageInfo(mid2) - mock_mqtt_client.unsubscribe.return_value = (fake_rc, mid3) - - # Initiate operations (1 -> 2 -> 3) - transport.subscribe(topic=topic1, qos=fake_qos, callback=callback1) - 
transport.publish(topic=topic2, payload="payload", qos=fake_qos, callback=callback2) - transport.unsubscribe(topic=topic3, callback=callback3) - - # Check callbacks have not yet been called - assert callback1.call_count == 0 - assert callback2.call_count == 0 - assert callback3.call_count == 0 - - # Manually trigger Paho on_unsubscribe event handler (2 -> 3 -> 1) - mock_mqtt_client.on_publish(client=mock_mqtt_client, userdata=None, mid=mid2) - assert callback1.call_count == 0 - assert callback2.call_count == 1 - assert callback3.call_count == 0 - - mock_mqtt_client.on_unsubscribe(client=mock_mqtt_client, userdata=None, mid=mid3) - assert callback1.call_count == 0 - assert callback2.call_count == 1 - assert callback3.call_count == 1 - - mock_mqtt_client.on_subscribe( - client=mock_mqtt_client, userdata=None, mid=mid1, granted_qos=fake_qos - ) - assert callback1.call_count == 1 - assert callback2.call_count == 1 - assert callback3.call_count == 1 - - -@pytest.mark.describe("OperationManager") -class TestOperationManager(object): - @pytest.mark.it("Instantiates with no operation tracking information") - def test_instantiates_empty(self): - manager = OperationManager() - assert len(manager._pending_operation_callbacks) == 0 - assert len(manager._unknown_operation_completions) == 0 - - -@pytest.mark.describe("OperationManager - .establish_operation()") -class TestOperationManagerEstablishOperation(object): - @pytest.fixture(params=[True, False]) - def optional_callback(self, mocker, request): - if request.param: - return mocker.MagicMock() - else: - return None - - @pytest.mark.it("Begins tracking a pending operation for a new MID") - @pytest.mark.parametrize( - "optional_callback", - [pytest.param(True, id="With callback"), pytest.param(False, id="No callback")], - indirect=True, - ) - def test_no_early_completion(self, optional_callback): - manager = OperationManager() - mid = 1 - manager.establish_operation(mid, optional_callback) - - assert 
len(manager._pending_operation_callbacks) == 1 - assert manager._pending_operation_callbacks[mid] is optional_callback - - @pytest.mark.it( - "Resolves operation tracking when MID corresponds to a previous unknown completion" - ) - def test_early_completion(self): - manager = OperationManager() - mid = 1 - - # Cause early completion of an unknown operation - manager.complete_operation(mid) - assert len(manager._unknown_operation_completions) == 1 - assert manager._unknown_operation_completions[mid] - - # Establish operation that was already completed - manager.establish_operation(mid) - - assert len(manager._unknown_operation_completions) == 0 - - @pytest.mark.it( - "Triggers the callback if provided when MID corresponds to a previous unknown completion" - ) - def test_early_completion_with_callback(self, mocker): - manager = OperationManager() - mid = 1 - cb_mock = mocker.MagicMock() - - # Cause early completion of an unknown operation - manager.complete_operation(mid) - - # Establish operation that was already completed - manager.establish_operation(mid, cb_mock) - - assert cb_mock.call_count == 1 - - @pytest.mark.it("Recovers from Exception thrown in callback") - def test_callback_raises_exception(self, mocker, arbitrary_exception): - manager = OperationManager() - mid = 1 - cb_mock = mocker.MagicMock(side_effect=arbitrary_exception) - - # Cause early completion of an unknown operation - manager.complete_operation(mid) - - # Establish operation that was already completed - manager.establish_operation(mid, cb_mock) - - # Callback was called, but exception did not propagate - assert cb_mock.call_count == 1 - - @pytest.mark.it("Allows any BaseExceptions raised in callback to propagate") - def test_callback_raises_base_exception(self, mocker, arbitrary_base_exception): - manager = OperationManager() - mid = 1 - cb_mock = mocker.MagicMock(side_effect=arbitrary_base_exception) - - # Cause early completion of an unknown operation - manager.complete_operation(mid) - - # 
Establish operation that was already completed - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - manager.establish_operation(mid, cb_mock) - assert e_info.value is arbitrary_base_exception - - @pytest.mark.it("Does not trigger the callback until after thread lock has been released") - def test_callback_called_after_lock_release(self, mocker): - manager = OperationManager() - mid = 1 - cb_mock = mocker.MagicMock() - - # Cause early completion of an unknown operation - manager.complete_operation(mid) - - # Set up mock tracking - lock_spy = mocker.spy(manager, "_lock") - mock_tracker = mocker.MagicMock() - calls_during_lock = [] - - # When the lock enters, start recording calls to callback - # When the lock exits, copy the list of calls. - - def track_mocks(): - mock_tracker.attach_mock(cb_mock, "cb") - - def stop_tracking_mocks(*args): - local_calls_during_lock = calls_during_lock # do this for python2 compat - local_calls_during_lock += copy.copy(mock_tracker.mock_calls) - mock_tracker.reset_mock() - - lock_spy.__enter__.side_effect = track_mocks - lock_spy.__exit__.side_effect = stop_tracking_mocks - - # Establish operation that was already completed - manager.establish_operation(mid, cb_mock) - - # Callback WAS called, but... 
- assert cb_mock.call_count == 1 - - # Callback WAS NOT called while the lock was held - assert mocker.call.cb() not in calls_during_lock - - -@pytest.mark.describe("OperationManager - .complete_operation()") -class TestOperationManagerCompleteOperation(object): - @pytest.mark.it("Resolves a operation tracking when MID corresponds to a pending operation") - def test_complete_pending_operation(self): - manager = OperationManager() - mid = 1 - - # Establish a pending operation - manager.establish_operation(mid) - assert len(manager._pending_operation_callbacks) == 1 - - # Complete pending operation - manager.complete_operation(mid) - assert len(manager._pending_operation_callbacks) == 0 - - @pytest.mark.it("Triggers callback for a pending operation when resolving") - def test_complete_pending_operation_callback(self, mocker): - manager = OperationManager() - mid = 1 - cb_mock = mocker.MagicMock() - - manager.establish_operation(mid, cb_mock) - assert cb_mock.call_count == 0 - - manager.complete_operation(mid) - assert cb_mock.call_count == 1 - assert cb_mock.call_args == mocker.call() - - @pytest.mark.it("Recovers from Exception thrown in callback") - def test_callback_raises_exception(self, mocker, arbitrary_exception): - manager = OperationManager() - mid = 1 - cb_mock = mocker.MagicMock(side_effect=arbitrary_exception) - - manager.establish_operation(mid, cb_mock) - assert cb_mock.call_count == 0 - - manager.complete_operation(mid) - # Callback was called but exception did not propagate - assert cb_mock.call_count == 1 - - @pytest.mark.it("Allows any BaseExceptions raised in callback to propagate") - def test_callback_raises_base_exception(self, mocker, arbitrary_base_exception): - manager = OperationManager() - mid = 1 - cb_mock = mocker.MagicMock(side_effect=arbitrary_base_exception) - - manager.establish_operation(mid, cb_mock) - assert cb_mock.call_count == 0 - - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - 
manager.complete_operation(mid) - assert e_info.value is arbitrary_base_exception - - @pytest.mark.it( - "Begins tracking an unknown completion if MID does not correspond to a pending operation" - ) - def test_early_completion(self): - manager = OperationManager() - mid = 1 - - manager.complete_operation(mid) - assert len(manager._unknown_operation_completions) == 1 - assert manager._unknown_operation_completions[mid] - - @pytest.mark.it("Does not trigger the callback until after thread lock has been released") - def test_callback_called_after_lock_release(self, mocker): - manager = OperationManager() - mid = 1 - cb_mock = mocker.MagicMock() - - # Set up an operation and save the callback - manager.establish_operation(mid, cb_mock) - - # Set up mock tracking - lock_spy = mocker.spy(manager, "_lock") - mock_tracker = mocker.MagicMock() - calls_during_lock = [] - - # When the lock enters, start recording calls to callback - # When the lock exits, copy the list of calls. - - def track_mocks(): - mock_tracker.attach_mock(cb_mock, "cb") - - def stop_tracking_mocks(*args): - local_calls_during_lock = calls_during_lock # do this for python2 compat - local_calls_during_lock += copy.copy(mock_tracker.mock_calls) - mock_tracker.reset_mock() - - lock_spy.__enter__.side_effect = track_mocks - lock_spy.__exit__.side_effect = stop_tracking_mocks - - # Complete the operation - manager.complete_operation(mid) - - # Callback WAS called, but... 
- assert cb_mock.call_count == 1 - assert cb_mock.call_args == mocker.call() - - # Callback WAS NOT called while the lock was held - assert mocker.call.cb() not in calls_during_lock - - -@pytest.mark.describe("OperationManager - .cancel_all_operations()") -class TestOperationManagerCancelAllOperations(object): - @pytest.mark.it("Removes all MID tracking for all pending operations") - def test_remove_pending_ops(self): - manager = OperationManager() - - # Establish pending operations - manager.establish_operation(mid=1) - manager.establish_operation(mid=2) - manager.establish_operation(mid=3) - assert len(manager._pending_operation_callbacks) == 3 - - # Cancel operations - manager.cancel_all_operations() - assert len(manager._pending_operation_callbacks) == 0 - - @pytest.mark.it("Removes all MID tracking for unknown operation completions") - def test_remove_unknown_completions(self): - manager = OperationManager() - - # Add unknown operation completions - manager.complete_operation(mid=2111) - manager.complete_operation(mid=30045) - manager.complete_operation(mid=2345) - assert len(manager._unknown_operation_completions) == 3 - - # Cancel operations - manager.cancel_all_operations() - assert len(manager._unknown_operation_completions) == 0 - - @pytest.mark.it("Triggers callbacks (if present) with cancel flag for each pending operation") - def test_op_callback_completion(self, mocker): - manager = OperationManager() - - # Establish pending operations - cb_mock1 = mocker.MagicMock() - manager.establish_operation(mid=1, callback=cb_mock1) - cb_mock2 = mocker.MagicMock() - manager.establish_operation(mid=2, callback=cb_mock2) - manager.establish_operation(mid=3, callback=None) - assert cb_mock1.call_count == 0 - assert cb_mock2.call_count == 0 - - # Cancel operations - manager.cancel_all_operations() - assert cb_mock1.call_count == 1 - assert cb_mock1.call_args == mocker.call(cancelled=True) - assert cb_mock2.call_count == 1 - assert cb_mock2.call_args == 
mocker.call(cancelled=True) - - @pytest.mark.it("Recovers from Exception thrown in callback") - def test_callback_raises_exception(self, mocker, arbitrary_exception): - manager = OperationManager() - - # Establish pending operation - cb_mock = mocker.MagicMock(side_effect=arbitrary_exception) - manager.establish_operation(mid=1, callback=cb_mock) - assert cb_mock.call_count == 0 - - # Cancel operations - manager.cancel_all_operations() - - # Callback was called but exception did not propagate - assert cb_mock.call_count == 1 - - @pytest.mark.it("Allows any BaseExceptions raised in callback to propagate") - def test_callback_raises_base_exception(self, mocker, arbitrary_base_exception): - manager = OperationManager() - - # Establish pending operation - cb_mock = mocker.MagicMock(side_effect=arbitrary_base_exception) - manager.establish_operation(mid=1, callback=cb_mock) - assert cb_mock.call_count == 0 - - # When cancelling operations, Base Exception propagates - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - manager.cancel_all_operations() - assert e_info.value is arbitrary_base_exception - - @pytest.mark.it("Does not trigger callbacks until after thread lock has been released") - def test_callback_called_after_lock_release(self, mocker): - manager = OperationManager() - cb_mock1 = mocker.MagicMock() - cb_mock2 = mocker.MagicMock() - - # Set up operations and save the callback - manager.establish_operation(mid=1, callback=cb_mock1) - manager.establish_operation(mid=2, callback=cb_mock2) - - # Set up mock tracking - lock_spy = mocker.spy(manager, "_lock") - mock_tracker = mocker.MagicMock() - calls_during_lock = [] - - # When the lock enters, start recording calls to callback - # When the lock exits, copy the list of calls. 
- - def track_mocks(): - mock_tracker.attach_mock(cb_mock1, "cb1") - mock_tracker.attach_mock(cb_mock2, "cb2") - - def stop_tracking_mocks(*args): - local_calls_during_lock = calls_during_lock # do this for python2 compat - local_calls_during_lock += copy.copy(mock_tracker.mock_calls) - mock_tracker.reset_mock() - - lock_spy.__enter__.side_effect = track_mocks - lock_spy.__exit__.side_effect = stop_tracking_mocks - - # Cancel operations - manager.cancel_all_operations() - - # Callbacks WERE called, but... - assert cb_mock1.call_count == 1 - assert cb_mock1.call_args == mocker.call(cancelled=True) - assert cb_mock2.call_count == 1 - assert cb_mock2.call_args == mocker.call(cancelled=True) - - # Callbacks WERE NOT called while the lock was held - assert mocker.call.cb1() not in calls_during_lock - assert mocker.call.cb2() not in calls_during_lock diff --git a/tests/unit/common/pipeline/conftest.py b/tests/unit/conftest.py similarity index 51% rename from tests/unit/common/pipeline/conftest.py rename to tests/unit/conftest.py index d751312ab..b60d7f535 100644 --- a/tests/unit/common/pipeline/conftest.py +++ b/tests/unit/conftest.py @@ -3,12 +3,16 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. 
# -------------------------------------------------------------------------- +""" NOTE: This module will not be necessary anymore once these tests are moved +to the main testing directory""" -from tests.unit.common.pipeline.fixtures import ( # noqa: F401 - arbitrary_event, - arbitrary_op, - fake_pipeline_thread, - fake_non_pipeline_thread, - pipeline_connected_mock, - nucleus, -) +import pytest + + +@pytest.fixture +def arbitrary_exception(): + class ArbitraryException(Exception): + pass + + e = ArbitraryException("arbitrary description") + return e diff --git a/tests/unit/iothub/__init__.py b/tests/unit/iothub/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/iothub/aio/__init__.py b/tests/unit/iothub/aio/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/iothub/aio/test_async_clients.py b/tests/unit/iothub/aio/test_async_clients.py deleted file mode 100644 index 299222883..000000000 --- a/tests/unit/iothub/aio/test_async_clients.py +++ /dev/null @@ -1,2432 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import logging -import pytest -import pytest_asyncio -import asyncio -import time -import urllib -from azure.iot.device import exceptions as client_exceptions -from azure.iot.device.common.auth import sastoken as st -from azure.iot.device.iothub.aio import IoTHubDeviceClient, IoTHubModuleClient -from azure.iot.device.iothub.pipeline import constant as pipeline_constant -from azure.iot.device.iothub.pipeline import exceptions as pipeline_exceptions -from azure.iot.device.iothub.pipeline import IoTHubPipelineConfig -from azure.iot.device.iothub.models import Message, MethodRequest -from azure.iot.device.iothub.abstract_clients import ( - RECEIVE_TYPE_NONE_SET, - RECEIVE_TYPE_HANDLER, - RECEIVE_TYPE_API, -) -from azure.iot.device.iothub.aio.async_inbox import AsyncClientInbox -from azure.iot.device.common import async_adapter -from azure.iot.device import constant as device_constant -from ..shared_client_tests import ( - SharedIoTHubClientInstantiationTests, - SharedIoTHubClientPROPERTYHandlerTests, - SharedIoTHubClientPROPERTYReceiverHandlerTests, - SharedIoTHubClientPROPERTYConnectedTests, - SharedIoTHubClientOCCURRENCEConnectTests, - SharedIoTHubClientOCCURRENCEDisconnectTests, - SharedIoTHubClientOCCURRENCENewSastokenRequired, - SharedIoTHubClientOCCURRENCEBackgroundException, - SharedIoTHubClientCreateFromConnectionStringTests, - SharedIoTHubDeviceClientCreateFromSastokenTests, - SharedIoTHubDeviceClientCreateFromSymmetricKeyTests, - SharedIoTHubDeviceClientCreateFromX509CertificateTests, - SharedIoTHubModuleClientCreateFromSastokenTests, - SharedIoTHubModuleClientCreateFromX509CertificateTests, - SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnvTests, - SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnvTests, -) - -pytestmark = pytest.mark.asyncio -logging.basicConfig(level=logging.DEBUG) - -# Python 3.6 only supports pytest_asyncio==0.16.0 which doesn't 
have pytest_asyncio.fixture. -try: - asyncio_fixture = pytest_asyncio.fixture -except AttributeError: - asyncio_fixture = pytest.fixture - - -async def create_completed_future(result=None): - f = asyncio.Future() - f.set_result(result) - return f - - -########################## -# SHARED CLIENT FIXTURES # -########################## -@pytest.fixture(params=["Handler Function", "Handler Coroutine"]) -def handler(request): - if request.param == "Handler Function": - - def _handler_function(arg): - pass - - return _handler_function - - else: - - async def _handler_coroutine(arg): - pass - - return _handler_coroutine - - -####################### -# SHARED CLIENT TESTS # -####################### -class SharedClientShutdownTests(object): - @pytest.mark.it("Performs a client disconnect (and everything that entails)") - async def test_calls_disconnect(self, mocker, client): - # We merely check that disconnect is called here. Doing so does several things, which - # are covered by the disconnect tests themselves. 
Those tests will NOT be duplicated here - client.disconnect = mocker.MagicMock() - client.disconnect.return_value = await create_completed_future(None) - assert client.disconnect.call_count == 0 - - await client.shutdown() - - assert client.disconnect.call_count == 1 - - @pytest.mark.it("Begins a 'shutdown' pipeline operation") - async def test_calls_pipeline_shutdown(self, mocker, client, mqtt_pipeline): - # mock out implicit disconnect - client.disconnect = mocker.MagicMock() - client.disconnect.return_value = await create_completed_future(None) - - await client.shutdown() - - assert mqtt_pipeline.shutdown.call_count == 1 - - @pytest.mark.it( - "Waits for the completion of the 'shutdown' pipeline operation before returning" - ) - async def test_waits_for_pipeline_op_completion(self, mocker, client, mqtt_pipeline): - # mock out implicit disconnect - client.disconnect = mocker.MagicMock() - client.disconnect.return_value = await create_completed_future(None) - # mock out callback - cb_mock = mocker.patch.object(async_adapter, "AwaitableCallback").return_value - cb_mock.completion.return_value = await create_completed_future(None) - - await client.shutdown() - - # Assert callback is sent to pipeline - assert mqtt_pipeline.shutdown.call_args[1]["callback"] is cb_mock - # Assert callback completion is waited upon - assert cb_mock.completion.call_count == 1 - - @pytest.mark.it( - "Raises a client error if the `shutdown` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - # The only other expected errors are unexpected ones. 
- pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - async def test_raises_error_on_pipeline_op_error( - self, mocker, client, mqtt_pipeline, pipeline_error, client_error - ): - # mock out implicit disconnect - client.disconnect = mocker.MagicMock() - client.disconnect.return_value = await create_completed_future(None) - - my_pipeline_error = pipeline_error() - - def fail_shutdown(callback): - callback(error=my_pipeline_error) - - mqtt_pipeline.shutdown.side_effect = fail_shutdown - - with pytest.raises(client_error) as e_info: - await client.shutdown() - assert e_info.value.__cause__ is my_pipeline_error - - @pytest.mark.it( - "Stops the client event handlers after the `shutdown` pipeline operation is complete" - ) - async def test_stops_client_event_handlers(self, mocker, client, mqtt_pipeline): - # mock out implicit disconnect - client.disconnect = mocker.MagicMock() - client.disconnect.return_value = await create_completed_future(None) - # Spy on handler manager stop. 
Note that while it does get called twice in shutdown, it - # only happens once here because we have mocked disconnect (where first stoppage) occurs - hm_stop_spy = mocker.spy(client._handler_manager, "stop") - - def check_handlers_and_complete(callback): - assert hm_stop_spy.call_count == 0 - callback() - - mqtt_pipeline.shutdown.side_effect = check_handlers_and_complete - - await client.shutdown() - - assert hm_stop_spy.call_count == 1 - assert hm_stop_spy.call_args == mocker.call(receiver_handlers_only=False) - - -class SharedClientConnectTests(object): - @pytest.mark.it("Begins a 'connect' pipeline operation") - async def test_calls_pipeline_connect(self, client, mqtt_pipeline): - await client.connect() - assert mqtt_pipeline.connect.call_count == 1 - - @pytest.mark.it("Waits for the completion of the 'connect' pipeline operation before returning") - async def test_waits_for_pipeline_op_completion(self, mocker, client, mqtt_pipeline): - cb_mock = mocker.patch.object(async_adapter, "AwaitableCallback").return_value - cb_mock.completion.return_value = await create_completed_future(None) - - await client.connect() - - # Assert callback is sent to pipeline - assert mqtt_pipeline.connect.call_args[1]["callback"] is cb_mock - # Assert callback completion is waited upon - assert cb_mock.completion.call_count == 1 - - @pytest.mark.it( - "Raises a client error if the `connect` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ConnectionDroppedError, - client_exceptions.ConnectionDroppedError, - id="ConnectionDroppedError->ConnectionDroppedError", - ), - pytest.param( - pipeline_exceptions.ConnectionFailedError, - client_exceptions.ConnectionFailedError, - id="ConnectionFailedError->ConnectionFailedError", - ), - pytest.param( - pipeline_exceptions.UnauthorizedError, - client_exceptions.CredentialError, - id="UnauthorizedError->CredentialError", - ), - 
pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.TlsExchangeAuthError, - client_exceptions.ClientError, - id="TlsExchangeAuthError->ClientError", - ), - pytest.param( - pipeline_exceptions.ProtocolProxyError, - client_exceptions.ClientError, - id="ProtocolProxyError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param( - pipeline_exceptions.OperationTimeout, - client_exceptions.OperationTimeout, - id="OperationTimeout->OperationTimeout", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - async def test_raises_error_on_pipeline_op_error( - self, mocker, client, mqtt_pipeline, pipeline_error, client_error - ): - my_pipeline_error = pipeline_error() - - def fail_connect(callback): - callback(error=my_pipeline_error) - - mqtt_pipeline.connect = mocker.MagicMock(side_effect=fail_connect) - with pytest.raises(client_error) as e_info: - await client.connect() - assert e_info.value.__cause__ is my_pipeline_error - assert mqtt_pipeline.connect.call_count == 1 - - -class SharedClientDisconnectTests(object): - @pytest.mark.it( - "Runs a 'disconnect' pipeline operation, stops the receiver handlers, then runs a second 'disconnect' pipeline operation" - ) - async def test_calls_pipeline_disconnect(self, mocker, client, mqtt_pipeline): - manager_mock = mocker.MagicMock() - client._handler_manager = mocker.MagicMock() - manager_mock.attach_mock(mqtt_pipeline.disconnect, "disconnect") - manager_mock.attach_mock(client._handler_manager.stop, "stop") - - await client.disconnect() - assert mqtt_pipeline.disconnect.call_count == 2 - assert client._handler_manager.stop.call_count == 1 - assert manager_mock.mock_calls == [ - mocker.call.disconnect(callback=mocker.ANY), - 
mocker.call.stop(receiver_handlers_only=True), - mocker.call.disconnect(callback=mocker.ANY), - ] - - @pytest.mark.it( - "Waits for the completion of both 'disconnect' pipeline operations before returning" - ) - async def test_waits_for_pipeline_op_completion(self, mocker, client, mqtt_pipeline): - # Make a mock that returns two different objects (since there are two AwaitableCallbacks used) - cb_mock1 = mocker.MagicMock() - cb_mock2 = mocker.MagicMock() - cb_mock1.completion.return_value = await create_completed_future(None) - cb_mock2.completion.return_value = await create_completed_future(None) - cb_mock_init = mocker.patch.object(async_adapter, "AwaitableCallback") - cb_mock_init.side_effect = [cb_mock1, cb_mock2] - - await client.disconnect() - - # Disconnect called twice - assert mqtt_pipeline.disconnect.call_count == 2 - # Assert callbacks sent to pipeline - assert mqtt_pipeline.disconnect.call_args_list[0][1]["callback"] is cb_mock1 - assert mqtt_pipeline.disconnect.call_args_list[1][1]["callback"] is cb_mock2 - # Assert callback completions were waited upon - assert cb_mock1.completion.call_count == 1 - assert cb_mock2.completion.call_count == 1 - - # Give the AwaitableCallback mock a standardized return value again, since - # .disconnect() will be called in cleanup - cb_mock_init.side_effect = None - cb_mock_init.return_value.completion.return_value = await create_completed_future(None) - - @pytest.mark.it( - "Raises a client error if the `disconnect` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - 
async def test_raises_error_on_pipeline_op_error( - self, mocker, client, mqtt_pipeline, pipeline_error, client_error - ): - my_pipeline_error = pipeline_error() - - def fail_disconnect(callback): - callback(error=my_pipeline_error) - - mqtt_pipeline.disconnect.side_effect = fail_disconnect - with pytest.raises(client_error) as e_info: - await client.disconnect() - assert e_info.value.__cause__ is my_pipeline_error - assert mqtt_pipeline.disconnect.call_count == 1 - - # Unset the side effect, since disconnect is used to clean up fixtures. - mqtt_pipeline.disconnect.side_effect = None - - -class SharedClientUpdateSasTokenTests(object): - # NOTE: Classes that inherit from this class must define some additional fixtures not included - # here, which will be specific to a device or module: - # - sas_config: returns an IoTHubPipelineConfiguration configured for Device/Module - # - uri: A uri that matches the uri in the SAS from sas_token_string fixture - # - nonmatching_uri: A uri that does NOT match to the uri in the SAS from sas_token_string - # - invalid_uri: A uri that is invalid (poorly formed, missing data, etc.) 
- - @pytest.fixture - def device_id(self, sas_token_string): - # NOTE: This is kind of unconventional, but this is the easiest way to extract the - # device id from a sastoken string - sastoken = st.NonRenewableSasToken(sas_token_string) - token_uri_pieces = sastoken.resource_uri.split("/") - device_id = token_uri_pieces[2] - return device_id - - @pytest.fixture - def hostname(self, sas_token_string): - # NOTE: This is kind of unconventional, but this is the easiest way to extract the - # hostname from a sastoken string - sastoken = st.NonRenewableSasToken(sas_token_string) - token_uri_pieces = sastoken.resource_uri.split("/") - hostname = token_uri_pieces[0] - return hostname - - @pytest.fixture - def sas_client(self, client_class, mqtt_pipeline, http_pipeline, sas_config): - """Client configured as if using user-provided, non-renewable SAS auth""" - mqtt_pipeline.pipeline_configuration = sas_config - http_pipeline.pipeline_configuration = sas_config - return client_class(mqtt_pipeline, http_pipeline) - - @pytest.fixture - def sas_client_manual_cb( - self, client_class, mqtt_pipeline_manual_cb, http_pipeline_manual_cb, sas_config - ): - mqtt_pipeline_manual_cb.pipeline_configuration = sas_config - http_pipeline_manual_cb.pipeline_configuration = sas_config - return client_class(mqtt_pipeline_manual_cb, http_pipeline_manual_cb) - - @pytest.fixture - def new_sas_token_string(self, uri): - # New SASToken String that matches old device id, module_id and hostname - signature = "AvCQCS7uVk8Lxau7rBs/jek4iwENIwLwpEV7NIJySc0=" - new_token_string = "SharedAccessSignature sr={uri}&sig={signature}&se={expiry}".format( - uri=urllib.parse.quote(uri, safe=""), - signature=urllib.parse.quote(signature, safe=""), - expiry=int(time.time()) + 3600, - ) - return new_token_string - - @pytest.mark.it( - "Creates a new NonRenewableSasToken and sets it on the PipelineConfig, if the new SAS Token string matches the existing SAS Token's information" - ) - async def 
test_updates_token_if_match_vals(self, sas_client, new_sas_token_string): - old_sas_token_string = str(sas_client._mqtt_pipeline.pipeline_configuration.sastoken) - - # Update to new token - await sas_client.update_sastoken(new_sas_token_string) - - # Sastoken was updated - assert ( - str(sas_client._mqtt_pipeline.pipeline_configuration.sastoken) == new_sas_token_string - ) - assert ( - str(sas_client._mqtt_pipeline.pipeline_configuration.sastoken) != old_sas_token_string - ) - - @pytest.mark.it("Begins a 'reauthorize connection' pipeline operation") - async def test_calls_pipeline_reauthorize( - self, sas_client, new_sas_token_string, mqtt_pipeline - ): - await sas_client.update_sastoken(new_sas_token_string) - assert mqtt_pipeline.reauthorize_connection.call_count == 1 - - @pytest.mark.it( - "Waits for the completion of the 'reauthorize connection' pipeline operation before returning" - ) - async def test_waits_for_pipeline_op_completion( - self, mocker, sas_client, mqtt_pipeline, new_sas_token_string - ): - cb_mock = mocker.patch.object(async_adapter, "AwaitableCallback").return_value - cb_mock.completion.return_value = await create_completed_future(None) - - await sas_client.update_sastoken(new_sas_token_string) - - # Assert callback is sent to pipeline - assert mqtt_pipeline.reauthorize_connection.call_args[1]["callback"] is cb_mock - # Assert callback completion is waited upon - assert cb_mock.completion.call_count == 1 - - @pytest.mark.it( - "Raises a ClientError if the 'reauthorize connection' pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ConnectionDroppedError, - client_exceptions.ConnectionDroppedError, - id="ConnectionDroppedError->ConnectionDroppedError", - ), - pytest.param( - pipeline_exceptions.ConnectionFailedError, - client_exceptions.ConnectionFailedError, - id="ConnectionFailedError->ConnectionFailedError", - ), - pytest.param( - 
pipeline_exceptions.UnauthorizedError, - client_exceptions.CredentialError, - id="UnauthorizedError->CredentialError", - ), - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.TlsExchangeAuthError, - client_exceptions.ClientError, - id="TlsExchangeAuthError->ClientError", - ), - pytest.param( - pipeline_exceptions.ProtocolProxyError, - client_exceptions.ClientError, - id="ProtocolProxyError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param( - pipeline_exceptions.OperationTimeout, - client_exceptions.OperationTimeout, - id="OperationTimeout->OperationTimeout", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - async def test_raises_error_on_pipeline_op_error( - self, mocker, sas_client, mqtt_pipeline, new_sas_token_string, client_error, pipeline_error - ): - # NOTE: If/When the MQTT pipeline is updated so that the reauthorize op waits for - # reconnection in order to return (currently it just waits for the disconnect), - # there will need to be additional connect-related errors in the parametrization. 
- my_pipeline_error = pipeline_error() - - def fail_reauth(callback): - callback(error=my_pipeline_error) - - mqtt_pipeline.reauthorize_connection = mocker.MagicMock(side_effect=fail_reauth) - with pytest.raises(client_error) as e_info: - await sas_client.update_sastoken(new_sas_token_string) - assert e_info.value.__cause__ is my_pipeline_error - assert mqtt_pipeline.reauthorize_connection.call_count == 1 - - @pytest.mark.it( - "Raises a ClientError if the client was created with an X509 certificate instead of SAS" - ) - async def test_created_with_x509(self, mocker, sas_client, new_sas_token_string): - # Modify client to seem as if created with X509 - x509_client = sas_client - x509_client._mqtt_pipeline.pipeline_configuration.sastoken = None - x509_client._mqtt_pipeline.pipeline_configuration.x509 = mocker.MagicMock() - - # Client raises error - with pytest.raises(client_exceptions.ClientError): - await x509_client.update_sastoken(new_sas_token_string) - - @pytest.mark.it( - "Raises a ClientError if the client was created with a renewable, non-user provided SAS (e.g. 
from connection string, symmetric key, etc.)" - ) - async def test_created_with_renewable_sas(self, mocker, sas_client, uri, new_sas_token_string): - # Modify client to seem as if created with renewable SAS - mock_signing_mechanism = mocker.MagicMock() - mock_signing_mechanism.sign.return_value = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=" - renewable_token = st.RenewableSasToken(uri, mock_signing_mechanism) - sas_client._mqtt_pipeline.pipeline_configuration.sastoken = renewable_token - - # Client raises error - with pytest.raises(client_exceptions.ClientError): - await sas_client.update_sastoken(new_sas_token_string) - - @pytest.mark.it("Raises a ValueError if there is an error creating a new NonRenewableSasToken") - async def test_token_error(self, mocker, sas_client, new_sas_token_string): - # NOTE: specific inputs that could cause this are tested in the sastoken test module - sastoken_mock = mocker.patch.object(st.NonRenewableSasToken, "__init__") - token_err = st.SasTokenError("Some SasToken failure") - sastoken_mock.side_effect = token_err - - with pytest.raises(ValueError) as e_info: - await sas_client.update_sastoken(new_sas_token_string) - assert e_info.value.__cause__ is token_err - - @pytest.mark.it("Raises ValueError if the provided SAS token string has already expired") - async def test_expired_token(self, mocker, sas_client, uri): - sastoken_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( - resource=urllib.parse.quote(uri, safe=""), - signature=urllib.parse.quote("ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=", safe=""), - expiry=int(time.time() - 3600), # expired - ) - - with pytest.raises(ValueError): - await sas_client.update_sastoken(sastoken_str) - - @pytest.mark.it( - "Raises ValueError if the provided SAS token string does not match the previous SAS details" - ) - async def test_nonmatching_uri_in_new_token(self, sas_client, nonmatching_uri): - signature = "AvCQCS7uVk8Lxau7rBs/jek4iwENIwLwpEV7NIJySc0=" - 
sastoken_str = "SharedAccessSignature sr={uri}&sig={signature}&se={expiry}".format( - uri=urllib.parse.quote(nonmatching_uri, safe=""), - signature=urllib.parse.quote(signature), - expiry=int(time.time()) + 3600, - ) - - with pytest.raises(ValueError): - await sas_client.update_sastoken(sastoken_str) - - @pytest.mark.it("Raises ValueError if the provided SAS token string has an invalid URI") - async def test_raises_value_error_invalid_uri(self, mocker, sas_client, invalid_uri): - sastoken_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( - resource=urllib.parse.quote(invalid_uri, safe=""), - signature=urllib.parse.quote("ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=", safe=""), - expiry=int(time.time() + 3600), - ) - - with pytest.raises(ValueError): - await sas_client.update_sastoken(sastoken_str) - - -class SharedClientSendD2CMessageTests(object): - @pytest.mark.it("Begins a 'send_message' pipeline operation") - async def test_calls_pipeline_send_message(self, client, mqtt_pipeline, message): - await client.send_message(message) - assert mqtt_pipeline.send_message.call_count == 1 - assert mqtt_pipeline.send_message.call_args[0][0] is message - - @pytest.mark.it( - "Waits for the completion of the 'send_message' pipeline operation before returning" - ) - async def test_waits_for_pipeline_op_completion(self, mocker, client, mqtt_pipeline, message): - cb_mock = mocker.patch.object(async_adapter, "AwaitableCallback").return_value - cb_mock.completion.return_value = await create_completed_future(None) - - await client.send_message(message) - - # Assert callback is sent to pipeline - assert mqtt_pipeline.send_message.call_args[1]["callback"] is cb_mock - # Assert callback completion is waited upon - assert cb_mock.completion.call_count == 1 - - @pytest.mark.it( - "Raises a client error if the `send_message` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - 
pytest.param( - pipeline_exceptions.ConnectionDroppedError, - client_exceptions.ConnectionDroppedError, - id="ConnectionDroppedError->ConnectionDroppedError", - ), - pytest.param( - pipeline_exceptions.ConnectionFailedError, - client_exceptions.ConnectionFailedError, - id="ConnectionFailedError->ConnectionFailedError", - ), - pytest.param( - pipeline_exceptions.NoConnectionError, - client_exceptions.NoConnectionError, - id="NoConnectionError->NoConnectionError", - ), - pytest.param( - pipeline_exceptions.UnauthorizedError, - client_exceptions.CredentialError, - id="UnauthorizedError->CredentialError", - ), - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param( - pipeline_exceptions.OperationTimeout, - client_exceptions.OperationTimeout, - id="OperationTimeout -> OperationTimeout", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - async def test_raises_error_on_pipeline_op_error( - self, mocker, client, mqtt_pipeline, message, client_error, pipeline_error - ): - my_pipeline_error = pipeline_error() - - def fail_send_message(message, callback): - callback(error=my_pipeline_error) - - mqtt_pipeline.send_message = mocker.MagicMock(side_effect=fail_send_message) - with pytest.raises(client_error) as e_info: - await client.send_message(message) - assert e_info.value.__cause__ is my_pipeline_error - assert mqtt_pipeline.send_message.call_count == 1 - - @pytest.mark.it( - "Wraps 'message' input parameter in a Message object if it is not a Message object" - ) - @pytest.mark.parametrize( - "message_input", - [ - pytest.param("message", id="String input"), - pytest.param(222, id="Integer input"), - pytest.param(object(), id="Object input"), - pytest.param(None, id="None 
input"), - pytest.param([1, "str"], id="List input"), - pytest.param({"a": 2}, id="Dictionary input"), - ], - ) - async def test_wraps_data_in_message_and_calls_pipeline_send_message( - self, client, mqtt_pipeline, message_input - ): - await client.send_message(message_input) - assert mqtt_pipeline.send_message.call_count == 1 - sent_message = mqtt_pipeline.send_message.call_args[0][0] - assert isinstance(sent_message, Message) - assert sent_message.data == message_input - - @pytest.mark.it("Raises error when message data size is greater than 256 KB") - async def test_raises_error_when_message_data_greater_than_256(self, client, mqtt_pipeline): - data_input = "serpensortia" * 256000 - message = Message(data_input) - with pytest.raises(ValueError) as e_info: - await client.send_message(message) - assert "256 KB" in e_info.value.args[0] - assert mqtt_pipeline.send_message.call_count == 0 - - @pytest.mark.it("Raises error when message size is greater than 256 KB") - async def test_raises_error_when_message_size_greater_than_256(self, client, mqtt_pipeline): - data_input = "serpensortia" - message = Message(data_input) - message.custom_properties["spell"] = data_input * 256000 - with pytest.raises(ValueError) as e_info: - await client.send_message(message) - assert "256 KB" in e_info.value.args[0] - assert mqtt_pipeline.send_message.call_count == 0 - - @pytest.mark.it("Does not raises error when message data size is equal to 256 KB") - async def test_raises_error_when_message_data_equal_to_256(self, client, mqtt_pipeline): - data_input = "a" * 262095 - message = Message(data_input) - # This check was put as message class may undergo the default content type encoding change - # and the above calculation will change. 
- if message.get_size() != device_constant.TELEMETRY_MESSAGE_SIZE_LIMIT: - assert False - - await client.send_message(message) - - assert mqtt_pipeline.send_message.call_count == 1 - sent_message = mqtt_pipeline.send_message.call_args[0][0] - assert isinstance(sent_message, Message) - assert sent_message.data == data_input - - -class SharedClientReceiveMethodRequestTests(object): - @pytest.mark.it("Implicitly enables methods feature if not already enabled") - @pytest.mark.parametrize( - "method_name", - [pytest.param(None, id="Generic Method"), pytest.param("method_x", id="Named Method")], - ) - async def test_enables_methods_only_if_not_already_enabled( - self, mocker, client, mqtt_pipeline, method_name - ): - # patch this so receive_method_request won't block - mocker.patch.object( - AsyncClientInbox, "get", return_value=(await create_completed_future(None)) - ) - - # Verify Input Messaging enabled if not enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = ( - False # Method Requests will appear disabled - ) - await client.receive_method_request(method_name) - assert mqtt_pipeline.enable_feature.call_count == 1 - assert mqtt_pipeline.enable_feature.call_args[0][0] == pipeline_constant.METHODS - - mqtt_pipeline.enable_feature.reset_mock() - - # Verify Input Messaging not enabled if already enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = ( - True # Input Messages will appear enabled - ) - await client.receive_method_request(method_name) - assert mqtt_pipeline.enable_feature.call_count == 0 - - @pytest.mark.it( - "Returns a MethodRequest from the generic method inbox, if available, when called without method name" - ) - async def test_called_without_method_name_returns_method_request_from_generic_method_inbox( - self, mocker, client - ): - request = MethodRequest(request_id="1", name="some_method", payload={"key": "value"}) - inbox_mock = mocker.MagicMock(autospec=AsyncClientInbox) - inbox_mock.get.return_value = await 
create_completed_future(request) - manager_get_inbox_mock = mocker.patch.object( - target=client._inbox_manager, - attribute="get_method_request_inbox", - return_value=inbox_mock, - ) - - received_request = await client.receive_method_request() - assert manager_get_inbox_mock.call_count == 1 - assert manager_get_inbox_mock.call_args == mocker.call(None) - assert inbox_mock.get.call_count == 1 - assert received_request is received_request - - @pytest.mark.it( - "Returns MethodRequest from the corresponding method inbox, if available, when called with a method name" - ) - async def test_called_with_method_name_returns_method_request_from_named_method_inbox( - self, mocker, client - ): - method_name = "some_method" - request = MethodRequest(request_id="1", name=method_name, payload={"key": "value"}) - inbox_mock = mocker.MagicMock(autospec=AsyncClientInbox) - inbox_mock.get.return_value = await create_completed_future(request) - manager_get_inbox_mock = mocker.patch.object( - target=client._inbox_manager, - attribute="get_method_request_inbox", - return_value=inbox_mock, - ) - - received_request = await client.receive_method_request(method_name) - assert manager_get_inbox_mock.call_count == 1 - assert manager_get_inbox_mock.call_args == mocker.call(method_name) - assert inbox_mock.get.call_count == 1 - assert received_request is received_request - - @pytest.mark.it("Locks the client to API Receive Mode if the receive mode has not yet been set") - @pytest.mark.parametrize( - "method_name", - [pytest.param(None, id="Generic Method"), pytest.param("method_x", id="Named Method")], - ) - async def test_receive_mode_not_set(self, mocker, client, method_name): - # patch this so receive_method_request won't block - mocker.patch.object( - AsyncClientInbox, "get", return_value=(await create_completed_future(None)) - ) - - assert client._receive_type is RECEIVE_TYPE_NONE_SET - await client.receive_method_request(method_name) - assert client._receive_type is RECEIVE_TYPE_API - - 
@pytest.mark.it( - "Does not modify the client receive mode if it has already been set to API Receive Mode" - ) - @pytest.mark.parametrize( - "method_name", - [pytest.param(None, id="Generic Method"), pytest.param("method_x", id="Named Method")], - ) - async def test_receive_mode_set_api(self, mocker, client, method_name): - # patch this so receive_method_request won't block - mocker.patch.object( - AsyncClientInbox, "get", return_value=(await create_completed_future(None)) - ) - - client._receive_type = RECEIVE_TYPE_API - await client.receive_method_request(method_name) - assert client._receive_type is RECEIVE_TYPE_API - - @pytest.mark.it( - "Raises a ClientError and does nothing else if the client receive mode has been set to Handler Receive Mode" - ) - @pytest.mark.parametrize( - "method_name", - [pytest.param(None, id="Generic Method"), pytest.param("method_x", id="Named Method")], - ) - async def test_receive_mode_set_handler(self, mocker, client, method_name, mqtt_pipeline): - # patch this so receive_method_request won't block - inbox_get_mock = mocker.patch.object( - AsyncClientInbox, "get", return_value=(await create_completed_future(None)) - ) - # patch this so we can make sure feature enabled isn't modified - mqtt_pipeline.feature_enabled.__getitem__.return_value = False - - client._receive_type = RECEIVE_TYPE_HANDLER - # Error was raised - with pytest.raises(client_exceptions.ClientError): - await client.receive_method_request(method_name) - # Feature was not enabled - assert mqtt_pipeline.enable_feature.call_count == 0 - # Inbox get was not called - assert inbox_get_mock.call_count == 0 - - -class SharedClientSendMethodResponseTests(object): - @pytest.mark.it("Begins a 'send_method_response' pipeline operation") - async def test_send_method_response_calls_pipeline( - self, client, mqtt_pipeline, method_response - ): - await client.send_method_response(method_response) - assert mqtt_pipeline.send_method_response.call_count == 1 - assert 
mqtt_pipeline.send_method_response.call_args[0][0] is method_response - - @pytest.mark.it( - "Waits for the completion of the 'send_method_response' pipeline operation before returning" - ) - async def test_waits_for_pipeline_op_completion( - self, mocker, client, mqtt_pipeline, method_response - ): - cb_mock = mocker.patch.object(async_adapter, "AwaitableCallback").return_value - cb_mock.completion.return_value = await create_completed_future(None) - - await client.send_method_response(method_response) - - # Assert callback is sent to pipeline - assert mqtt_pipeline.send_method_response.call_args[1]["callback"] is cb_mock - # Assert callback completion is waited upon - assert cb_mock.completion.call_count == 1 - - @pytest.mark.it( - "Raises a client error if the `send_method_response` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ConnectionDroppedError, - client_exceptions.ConnectionDroppedError, - id="ConnectionDroppedError->ConnectionDroppedError", - ), - pytest.param( - pipeline_exceptions.ConnectionFailedError, - client_exceptions.ConnectionFailedError, - id="ConnectionFailedError->ConnectionFailedError", - ), - pytest.param( - pipeline_exceptions.NoConnectionError, - client_exceptions.NoConnectionError, - id="NoConnectionError->NoConnectionError", - ), - pytest.param( - pipeline_exceptions.UnauthorizedError, - client_exceptions.CredentialError, - id="UnauthorizedError->CredentialError", - ), - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param( - pipeline_exceptions.OperationTimeout, - client_exceptions.OperationTimeout, - id="OperationTimeout -> OperationTimeout", - ), - pytest.param(Exception, 
client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - async def test_raises_error_on_pipeline_op_error( - self, mocker, client, mqtt_pipeline, method_response, pipeline_error, client_error - ): - my_pipeline_error = pipeline_error() - - def fail_send_method_response(response, callback): - callback(error=my_pipeline_error) - - mqtt_pipeline.send_method_response = mocker.MagicMock(side_effect=fail_send_method_response) - with pytest.raises(client_error) as e_info: - await client.send_method_response(method_response) - assert e_info.value.__cause__ is my_pipeline_error - assert mqtt_pipeline.send_method_response.call_count == 1 - - -class SharedClientGetTwinTests(object): - @pytest.mark.it("Implicitly enables twin messaging feature if not already enabled") - async def test_enables_twin_only_if_not_already_enabled( - self, mocker, client, mqtt_pipeline, fake_twin - ): - # patch this so get_twin won't block - def immediate_callback(callback): - callback(twin=fake_twin) - - mocker.patch.object(mqtt_pipeline, "get_twin", side_effect=immediate_callback) - - # Verify twin enabled if not enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = False # twin will appear disabled - await client.get_twin() - assert mqtt_pipeline.enable_feature.call_count == 1 - assert mqtt_pipeline.enable_feature.call_args[0][0] == pipeline_constant.TWIN - - mqtt_pipeline.enable_feature.reset_mock() - - # Verify twin not enabled if already enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = True # twin will appear enabled - await client.get_twin() - assert mqtt_pipeline.enable_feature.call_count == 0 - - @pytest.mark.it("Begins a 'get_twin' pipeline operation") - async def test_get_twin_calls_pipeline(self, client, mqtt_pipeline, mocker, fake_twin): - def immediate_callback(callback): - callback(twin=fake_twin) - - mocker.patch.object(mqtt_pipeline, "get_twin", side_effect=immediate_callback) - await client.get_twin() - assert 
mqtt_pipeline.get_twin.call_count == 1 - - @pytest.mark.it( - "Waits for the completion of the 'get_twin' pipeline operation before returning" - ) - async def test_waits_for_pipeline_op_completion(self, mocker, client, mqtt_pipeline): - cb_mock = mocker.patch.object(async_adapter, "AwaitableCallback").return_value - cb_mock.completion.return_value = await create_completed_future(None) - mqtt_pipeline.feature_enabled.__getitem__.return_value = True # twin will appear enabled - - await client.get_twin() - - # Assert callback is sent to pipeline - mqtt_pipeline.get_twin.call_args[1]["callback"] is cb_mock - # Assert callback completion is waited upon - assert cb_mock.completion.call_count == 1 - - @pytest.mark.it( - "Raises a client error if the `get_twin` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ConnectionDroppedError, - client_exceptions.ConnectionDroppedError, - id="ConnectionDroppedError->ConnectionDroppedError", - ), - pytest.param( - pipeline_exceptions.ConnectionFailedError, - client_exceptions.ConnectionFailedError, - id="ConnectionFailedError->ConnectionFailedError", - ), - pytest.param( - pipeline_exceptions.NoConnectionError, - client_exceptions.NoConnectionError, - id="NoConnectionError->NoConnectionError", - ), - pytest.param( - pipeline_exceptions.UnauthorizedError, - client_exceptions.CredentialError, - id="UnauthorizedError->CredentialError", - ), - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param( - pipeline_exceptions.OperationTimeout, - client_exceptions.OperationTimeout, - id="OperationTimeout -> OperationTimeout", - ), - pytest.param(Exception, client_exceptions.ClientError, 
id="Exception->ClientError"), - ], - ) - async def test_raises_error_on_pipeline_op_error( - self, mocker, client, mqtt_pipeline, pipeline_error, client_error - ): - my_pipeline_error = pipeline_error() - - def fail_get_twin(callback): - callback(error=my_pipeline_error) - - mqtt_pipeline.get_twin = mocker.MagicMock(side_effect=fail_get_twin) - with pytest.raises(client_error) as e_info: - await client.get_twin() - assert e_info.value.__cause__ is my_pipeline_error - assert mqtt_pipeline.get_twin.call_count == 1 - - @pytest.mark.it("Returns the twin that the pipeline returned") - async def test_verifies_twin_returned(self, mocker, client, mqtt_pipeline, fake_twin): - - # make the pipeline the twin - def immediate_callback(callback): - callback(twin=fake_twin) - - mocker.patch.object(mqtt_pipeline, "get_twin", side_effect=immediate_callback) - - returned_twin = await client.get_twin() - assert returned_twin == fake_twin - - -class SharedClientPatchTwinReportedPropertiesTests(object): - @pytest.mark.it("Implicitly enables twin messaging feature if not already enabled") - async def test_enables_twin_only_if_not_already_enabled( - self, mocker, client, mqtt_pipeline, twin_patch_reported - ): - # patch this so x_get_twin won't block - def immediate_callback(patch, callback): - callback() - - mocker.patch.object( - mqtt_pipeline, "patch_twin_reported_properties", side_effect=immediate_callback - ) - - # Verify twin enabled if not enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = False # twin will appear disabled - await client.patch_twin_reported_properties(twin_patch_reported) - assert mqtt_pipeline.enable_feature.call_count == 1 - assert mqtt_pipeline.enable_feature.call_args[0][0] == pipeline_constant.TWIN - - mqtt_pipeline.enable_feature.reset_mock() - - # Verify twin not enabled if already enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = True # twin will appear enabled - await 
client.patch_twin_reported_properties(twin_patch_reported) - assert mqtt_pipeline.enable_feature.call_count == 0 - - @pytest.mark.it("Begins a 'patch_twin_reported_properties' pipeline operation") - async def test_patch_twin_reported_properties_calls_pipeline( - self, client, mqtt_pipeline, twin_patch_reported - ): - await client.patch_twin_reported_properties(twin_patch_reported) - assert mqtt_pipeline.patch_twin_reported_properties.call_count == 1 - assert ( - mqtt_pipeline.patch_twin_reported_properties.call_args[1]["patch"] - is twin_patch_reported - ) - - @pytest.mark.it( - "Waits for the completion of the 'patch_twin_reported_properties' pipeline operation before returning" - ) - async def test_waits_for_pipeline_op_completion( - self, mocker, client, mqtt_pipeline, twin_patch_reported - ): - cb_mock = mocker.patch.object(async_adapter, "AwaitableCallback").return_value - cb_mock.completion.return_value = await create_completed_future(None) - mqtt_pipeline.feature_enabled.__getitem__.return_value = True # twin will appear enabled - - await client.patch_twin_reported_properties(twin_patch_reported) - - # Assert callback is sent to pipeline - assert mqtt_pipeline.patch_twin_reported_properties.call_args[1]["callback"] is cb_mock - # Assert callback completion is waited upon - assert cb_mock.completion.call_count == 1 - - @pytest.mark.it( - "Raises a client error if the `patch_twin_reported_properties` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ConnectionDroppedError, - client_exceptions.ConnectionDroppedError, - id="ConnectionDroppedError->ConnectionDroppedError", - ), - pytest.param( - pipeline_exceptions.ConnectionFailedError, - client_exceptions.ConnectionFailedError, - id="ConnectionFailedError->ConnectionFailedError", - ), - pytest.param( - pipeline_exceptions.NoConnectionError, - client_exceptions.NoConnectionError, - 
id="NoConnectionError->NoConnectionError", - ), - pytest.param( - pipeline_exceptions.UnauthorizedError, - client_exceptions.CredentialError, - id="UnauthorizedError->CredentialError", - ), - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param( - pipeline_exceptions.OperationTimeout, - client_exceptions.OperationTimeout, - id="OperationTimeout -> OperationTimeout", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - async def test_raises_error_on_pipeline_op_error( - self, mocker, client, mqtt_pipeline, twin_patch_reported, pipeline_error, client_error - ): - my_pipeline_error = pipeline_error() - - def fail_patch_twin_reported_properties(patch, callback): - callback(error=my_pipeline_error) - - mqtt_pipeline.patch_twin_reported_properties = mocker.MagicMock( - side_effect=fail_patch_twin_reported_properties - ) - with pytest.raises(client_error) as e_info: - await client.patch_twin_reported_properties(twin_patch_reported) - assert e_info.value.__cause__ is my_pipeline_error - assert mqtt_pipeline.patch_twin_reported_properties.call_count == 1 - - -class SharedClientReceiveTwinDesiredPropertiesPatchTests(object): - @pytest.mark.it("Implicitly enables twin patch messaging feature if not already enabled") - async def test_enables_c2d_messaging_only_if_not_already_enabled( - self, mocker, client, mqtt_pipeline - ): - # patch this receive_twin_desired_properties_patch won't block - mocker.patch.object( - AsyncClientInbox, "get", return_value=(await create_completed_future(None)) - ) - - # Verify twin patches are enabled if not enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = ( - False # twin patches will appear disabled - ) - await 
client.receive_twin_desired_properties_patch() - assert mqtt_pipeline.enable_feature.call_count == 1 - assert mqtt_pipeline.enable_feature.call_args[0][0] == pipeline_constant.TWIN_PATCHES - - mqtt_pipeline.enable_feature.reset_mock() - - # Verify twin patches are not enabled if already enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = ( - True # twin patches will appear enabled - ) - await client.receive_twin_desired_properties_patch() - assert mqtt_pipeline.enable_feature.call_count == 0 - - @pytest.mark.it("Returns a message from the twin patch inbox, if available") - async def test_returns_message_from_twin_patch_inbox(self, mocker, client, twin_patch_desired): - inbox_mock = mocker.MagicMock(autospec=AsyncClientInbox) - inbox_mock.get.return_value = await create_completed_future(twin_patch_desired) - manager_get_inbox_mock = mocker.patch.object( - client._inbox_manager, "get_twin_patch_inbox", return_value=inbox_mock - ) - - received_patch = await client.receive_twin_desired_properties_patch() - assert manager_get_inbox_mock.call_count == 1 - assert inbox_mock.get.call_count == 1 - assert received_patch is twin_patch_desired - - @pytest.mark.it("Locks the client to API Receive Mode if the receive mode has not yet been set") - async def test_receive_mode_not_set(self, mocker, client): - # patch this so API won't block - mocker.patch.object( - AsyncClientInbox, "get", return_value=(await create_completed_future(None)) - ) - - assert client._receive_type is RECEIVE_TYPE_NONE_SET - await client.receive_twin_desired_properties_patch() - assert client._receive_type is RECEIVE_TYPE_API - - @pytest.mark.it( - "Does not modify the client receive mode if it has already been set to API Receive Mode" - ) - async def test_receive_mode_set_api(self, mocker, client): - # patch this so API won't block - mocker.patch.object( - AsyncClientInbox, "get", return_value=(await create_completed_future(None)) - ) - - client._receive_type = RECEIVE_TYPE_API - await 
client.receive_twin_desired_properties_patch() - assert client._receive_type is RECEIVE_TYPE_API - - @pytest.mark.it( - "Raises a ClientError and does nothing else if the client receive mode has been set to Handler Receive Mode" - ) - async def test_receive_mode_set_handler(self, mocker, client, mqtt_pipeline): - # patch this so API won't block - inbox_get_mock = mocker.patch.object( - AsyncClientInbox, "get", return_value=(await create_completed_future(None)) - ) - # patch this so we can make sure feature enabled isn't modified - mqtt_pipeline.feature_enabled.__getitem__.return_value = False - - client._receive_type = RECEIVE_TYPE_HANDLER - # Error was raised - with pytest.raises(client_exceptions.ClientError): - await client.receive_twin_desired_properties_patch() - # Feature was not enabled - assert mqtt_pipeline.enable_feature.call_count == 0 - # Inbox get was not called - assert inbox_get_mock.call_count == 0 - - -# ################ -# # DEVICE TESTS # -# ################ -class IoTHubDeviceClientTestsConfig(object): - @pytest.fixture - def client_class(self): - return IoTHubDeviceClient - - @asyncio_fixture - async def client(self, mqtt_pipeline, http_pipeline): - """This client automatically resolves callbacks sent to the pipeline. - It should be used for the majority of tests. - """ - client = IoTHubDeviceClient(mqtt_pipeline, http_pipeline) - yield client - await client.shutdown() - - @pytest.fixture - def connection_string(self, device_connection_string): - """This fixture is parametrized to prove all valid device connection strings. 
- See client_fixtures.py - """ - return device_connection_string - - @pytest.fixture - def sas_token_string(self, device_sas_token_string): - return device_sas_token_string - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - Instantiation") -class TestIoTHubDeviceClientInstantiation( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientInstantiationTests -): - @pytest.mark.it("Sets on_c2d_message_received handler in the MQTTPipeline") - async def test_sets_on_c2d_message_received_handler_in_pipeline( - self, client_class, mqtt_pipeline, http_pipeline - ): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline.on_c2d_message_received is not None - assert ( - client._mqtt_pipeline.on_c2d_message_received == client._inbox_manager.route_c2d_message - ) - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .create_from_connection_string()") -class TestIoTHubDeviceClientCreateFromConnectionString( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientCreateFromConnectionStringTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .create_from_sastoken()") -class TestIoTHubDeviceClientCreateFromSastoken( - IoTHubDeviceClientTestsConfig, SharedIoTHubDeviceClientCreateFromSastokenTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .create_from_symmetric_key()") -class TestConfigurationCreateIoTHubDeviceClientFromSymmetricKey( - IoTHubDeviceClientTestsConfig, SharedIoTHubDeviceClientCreateFromSymmetricKeyTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .create_from_x509_certificate()") -class TestIoTHubDeviceClientCreateFromX509Certificate( - IoTHubDeviceClientTestsConfig, SharedIoTHubDeviceClientCreateFromX509CertificateTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .shutdown()") -class TestIoTHubDeviceClientShutdown(IoTHubDeviceClientTestsConfig, SharedClientShutdownTests): - @pytest.fixture - def 
client(self, mqtt_pipeline, http_pipeline): - """Override the client so that it doesn't shutdown during cleanup. - Shutdown can only be done once, and it is under test, so shutting down again during cleanup - will fail. - """ - return IoTHubDeviceClient(mqtt_pipeline, http_pipeline) - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .update_sastoken()") -class TestIoTHubDeviceClientUpdateSasToken( - IoTHubDeviceClientTestsConfig, SharedClientUpdateSasTokenTests -): - @pytest.fixture - def sas_config(self, sas_token_string): - """PipelineConfig set up as if using user-provided, non-renewable SAS auth""" - sastoken = st.NonRenewableSasToken(sas_token_string) - token_uri_pieces = sastoken.resource_uri.split("/") - hostname = token_uri_pieces[0] - device_id = token_uri_pieces[2] - sas_config = IoTHubPipelineConfig(hostname=hostname, device_id=device_id, sastoken=sastoken) - return sas_config - - @pytest.fixture - def uri(self, hostname, device_id): - return "{hostname}/devices/{device_id}".format(hostname=hostname, device_id=device_id) - - @pytest.fixture(params=["Nonmatching Device ID", "Nonmatching Hostname"]) - def nonmatching_uri(self, request, device_id, hostname): - # NOTE: It would be preferable to have this as a parametrization on a test rather than a - # fixture, however, we need to use the device_id and hostname fixtures in order to ensure - # tests don't break when other fixtures change, and you can't include fixtures in a - # parametrization, so this also has to be a fixture - uri_format = "{hostname}/devices/{device_id}" - if request.param == "Nonmatching Device ID": - return uri_format.format(hostname=hostname, device_id="nonmatching_device") - else: - return uri_format.format(hostname="nonmatching_hostname", device_id=device_id) - - @pytest.fixture( - params=["Too short", "Too long", "Incorrectly formatted device notation", "Module URI"] - ) - def invalid_uri(self, request, device_id, hostname): - # NOTE: As in the nonmatching_uri fixture 
above, this is a workaround for parametrization - # that allows the usage of other fixtures in the parametrized value. Weird pattern, but - # necessary to ensure stability of the tests over time. - if request.param == "Too short": - # Doesn't have device ID - return hostname + "/devices" - elif request.param == "Too long": - # Extraneous value at the end - return "{}/devices/{}/somethingElse".format(hostname, device_id) - elif request.param == "Incorrectly formatted device notation": - # Doesn't have '/devices/' - return "{}/not-devices/{}".format(hostname, device_id) - else: - # Valid... for a Module... but this is a Device - return "{}/devices/{}/modules/my_module".format(hostname, device_id) - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .connect()") -class TestIoTHubDeviceClientConnect(IoTHubDeviceClientTestsConfig, SharedClientConnectTests): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .disconnect()") -class TestIoTHubDeviceClientDisconnect(IoTHubDeviceClientTestsConfig, SharedClientDisconnectTests): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .send_message()") -class TestIoTHubDeviceClientSendD2CMessage( - IoTHubDeviceClientTestsConfig, SharedClientSendD2CMessageTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .receive_message()") -class TestIoTHubDeviceClientReceiveC2DMessage(IoTHubDeviceClientTestsConfig): - @pytest.mark.it("Implicitly enables C2D messaging feature if not already enabled") - async def test_enables_c2d_messaging_only_if_not_already_enabled( - self, mocker, client, mqtt_pipeline - ): - # patch this receive_message won't block - mocker.patch.object( - AsyncClientInbox, "get", return_value=(await create_completed_future(None)) - ) - - # Verify C2D Messaging enabled if not enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = False # C2D will appear disabled - await client.receive_message() - assert 
mqtt_pipeline.enable_feature.call_count == 1 - assert mqtt_pipeline.enable_feature.call_args[0][0] == pipeline_constant.C2D_MSG - - mqtt_pipeline.enable_feature.reset_mock() - - # Verify C2D Messaging not enabled if already enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = True # C2D will appear enabled - await client.receive_message() - assert mqtt_pipeline.enable_feature.call_count == 0 - - @pytest.mark.it("Returns a message from the C2D inbox, if available") - async def test_returns_message_from_c2d_inbox(self, mocker, client, message): - inbox_mock = mocker.MagicMock(autospec=AsyncClientInbox) - inbox_mock.get.return_value = await create_completed_future(message) - manager_get_inbox_mock = mocker.patch.object( - client._inbox_manager, "get_c2d_message_inbox", return_value=inbox_mock - ) - - received_message = await client.receive_message() - assert manager_get_inbox_mock.call_count == 1 - assert inbox_mock.get.call_count == 1 - assert received_message is message - - @pytest.mark.it("Locks the client to API Receive Mode if the receive mode has not yet been set") - async def test_receive_mode_not_set(self, mocker, client): - # patch this so API won't block - mocker.patch.object( - AsyncClientInbox, "get", return_value=(await create_completed_future(None)) - ) - - assert client._receive_type is RECEIVE_TYPE_NONE_SET - await client.receive_message() - assert client._receive_type is RECEIVE_TYPE_API - - @pytest.mark.it( - "Does not modify the client receive mode if it has already been set to API Receive Mode" - ) - async def test_receive_mode_set_api(self, mocker, client): - # patch this so API won't block - mocker.patch.object( - AsyncClientInbox, "get", return_value=(await create_completed_future(None)) - ) - - client._receive_type = RECEIVE_TYPE_API - await client.receive_message() - assert client._receive_type is RECEIVE_TYPE_API - - @pytest.mark.it( - "Raises a ClientError and does nothing else if the client receive mode has been set to Handler 
Receive Mode" - ) - async def test_receive_mode_set_handler(self, mocker, client, mqtt_pipeline): - # patch this so API won't block - inbox_get_mock = mocker.patch.object( - AsyncClientInbox, "get", return_value=(await create_completed_future(None)) - ) - # patch this so we can make sure feature enabled isn't modified - mqtt_pipeline.feature_enabled.__getitem__.return_value = False - - client._receive_type = RECEIVE_TYPE_HANDLER - # Error was raised - with pytest.raises(client_exceptions.ClientError): - await client.receive_message() - # Feature was not enabled - assert mqtt_pipeline.enable_feature.call_count == 0 - # Inbox get was not called - assert inbox_get_mock.call_count == 0 - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .receive_method_request()") -class TestIoTHubDeviceClientReceiveMethodRequest( - IoTHubDeviceClientTestsConfig, SharedClientReceiveMethodRequestTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .send_method_response()") -class TestIoTHubDeviceClientSendMethodResponse( - IoTHubDeviceClientTestsConfig, SharedClientSendMethodResponseTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .get_twin()") -class TestIoTHubDeviceClientGetTwin(IoTHubDeviceClientTestsConfig, SharedClientGetTwinTests): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .patch_twin_reported_properties()") -class TestIoTHubDeviceClientPatchTwinReportedProperties( - IoTHubDeviceClientTestsConfig, SharedClientPatchTwinReportedPropertiesTests -): - pass - - -@pytest.mark.describe( - "IoTHubDeviceClient (Asynchronous) - .receive_twin_desired_properties_patch()" -) -class TestIoTHubDeviceClientReceiveTwinDesiredPropertiesPatch( - IoTHubDeviceClientTestsConfig, SharedClientReceiveTwinDesiredPropertiesPatchTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .get_storage_info_for_blob()") -class 
TestIoTHubDeviceClientGetStorageInfo(IoTHubDeviceClientTestsConfig): - @pytest.mark.it("Begins a 'get_storage_info_for_blob' HTTPPipeline operation") - async def test_calls_pipeline_get_storage_info_for_blob(self, client, http_pipeline): - fake_blob_name = "__fake_blob_name__" - await client.get_storage_info_for_blob(fake_blob_name) - assert http_pipeline.get_storage_info_for_blob.call_count == 1 - assert http_pipeline.get_storage_info_for_blob.call_args[1]["blob_name"] is fake_blob_name - - @pytest.mark.it( - "Waits for the completion of the 'get_storage_info_for_blob' pipeline operation before returning" - ) - async def test_waits_for_pipeline_op_completion(self, mocker, client, http_pipeline): - fake_blob_name = "__fake_blob_name__" - cb_mock = mocker.patch.object(async_adapter, "AwaitableCallback").return_value - cb_mock.completion.return_value = await create_completed_future(None) - - await client.get_storage_info_for_blob(fake_blob_name) - - # Assert callback is sent to pipeline - assert http_pipeline.get_storage_info_for_blob.call_args[1]["callback"] is cb_mock - # Assert callback completion is waited upon - assert cb_mock.completion.call_count == 1 - - @pytest.mark.it( - "Raises a client error if the `get_storage_info_for_blob` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - async def test_raises_error_on_pipeline_op_error( - self, mocker, client, http_pipeline, pipeline_error, client_error - ): - fake_blob_name = "__fake_blob_name__" - - my_pipeline_error = pipeline_error() - - def 
fail_get_storage_info_for_blob(blob_name, callback): - callback(error=my_pipeline_error) - - http_pipeline.get_storage_info_for_blob = mocker.MagicMock( - side_effect=fail_get_storage_info_for_blob - ) - - with pytest.raises(client_error) as e_info: - await client.get_storage_info_for_blob(fake_blob_name) - assert e_info.value.__cause__ is my_pipeline_error - - @pytest.mark.it("Returns a storage_info object upon successful completion") - async def test_returns_storage_info(self, mocker, client, http_pipeline): - fake_blob_name = "__fake_blob_name__" - fake_storage_info = "__fake_storage_info__" - received_storage_info = await client.get_storage_info_for_blob(fake_blob_name) - assert http_pipeline.get_storage_info_for_blob.call_count == 1 - assert http_pipeline.get_storage_info_for_blob.call_args[1]["blob_name"] is fake_blob_name - - assert ( - received_storage_info is fake_storage_info - ) # Note: the return value this is checking for is defined in client_fixtures.py - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) -.notify_blob_upload_status()") -class TestIoTHubDeviceClientNotifyBlobUploadStatus(IoTHubDeviceClientTestsConfig): - @pytest.mark.it("Begins a 'notify_blob_upload_status' HTTPPipeline operation") - async def test_calls_pipeline_notify_blob_upload_status(self, client, http_pipeline): - correlation_id = "__fake_correlation_id__" - is_success = "__fake_is_success__" - status_code = "__fake_status_code__" - status_description = "__fake_status_description__" - await client.notify_blob_upload_status( - correlation_id, is_success, status_code, status_description - ) - kwargs = http_pipeline.notify_blob_upload_status.call_args[1] - assert http_pipeline.notify_blob_upload_status.call_count == 1 - assert kwargs["correlation_id"] is correlation_id - assert kwargs["is_success"] is is_success - assert kwargs["status_code"] is status_code - assert kwargs["status_description"] is status_description - - @pytest.mark.it( - "Waits for the completion of the 
'notify_blob_upload_status' pipeline operation before returning" - ) - async def test_waits_for_pipeline_op_completion(self, mocker, client, http_pipeline): - correlation_id = "__fake_correlation_id__" - is_success = "__fake_is_success__" - status_code = "__fake_status_code__" - status_description = "__fake_status_description__" - cb_mock = mocker.patch.object(async_adapter, "AwaitableCallback").return_value - cb_mock.completion.return_value = await create_completed_future(None) - await client.notify_blob_upload_status( - correlation_id, is_success, status_code, status_description - ) - - # Assert callback is sent to pipeline - assert http_pipeline.notify_blob_upload_status.call_args[1]["callback"] is cb_mock - # Assert callback completion is waited upon - assert cb_mock.completion.call_count == 1 - - @pytest.mark.it( - "Raises a client error if the `notify_blob_upload_status` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - async def test_raises_error_on_pipeline_op_error( - self, mocker, client, http_pipeline, pipeline_error, client_error - ): - correlation_id = "__fake_correlation_id__" - is_success = "__fake_is_success__" - status_code = "__fake_status_code__" - status_description = "__fake_status_description__" - my_pipeline_error = pipeline_error() - - def fail_notify_blob_upload_status( - correlation_id, is_success, status_code, status_description, callback - ): - callback(error=my_pipeline_error) - - http_pipeline.notify_blob_upload_status = mocker.MagicMock( - side_effect=fail_notify_blob_upload_status - ) - - with pytest.raises(client_error) as e_info: - await client.notify_blob_upload_status( - correlation_id, is_success, status_code, 
status_description - ) - assert e_info.value.__cause__ is my_pipeline_error - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - PROPERTY .on_message_received") -class TestIoTHubDeviceClientPROPERTYOnMessageReceivedHandler( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientPROPERTYReceiverHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_message_received" - - @pytest.fixture - def feature_name(self): - return pipeline_constant.C2D_MSG - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - PROPERTY .on_method_request_received") -class TestIoTHubDeviceClientPROPERTYOnMethodRequestReceivedHandler( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientPROPERTYReceiverHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_method_request_received" - - @pytest.fixture - def feature_name(self): - return pipeline_constant.METHODS - - -@pytest.mark.describe( - "IoTHubDeviceClient (Asynchronous) - PROPERTY .on_twin_desired_properties_patch_received" -) -class TestIoTHubDeviceClientPROPERTYOnTwinDesiredPropertiesPatchReceivedHandler( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientPROPERTYReceiverHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_twin_desired_properties_patch_received" - - @pytest.fixture - def feature_name(self): - return pipeline_constant.TWIN_PATCHES - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - PROPERTY .on_connection_state_change") -class TestIoTHubDeviceClientPROPERTYOnConnectionStateChangeHandler( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientPROPERTYHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_connection_state_change" - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - PROPERTY .on_new_sastoken_required") -class TestIoTHubDeviceClientPROPERTYOnNewSastokenRequiredHandler( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientPROPERTYHandlerTests -): - @pytest.fixture - def handler_name(self): - return 
"on_new_sastoken_required" - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - PROPERTY .on_background_exception") -class TestIoTHubDeviceClientPROPERTYOnBackgroundExceptionHandler( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientPROPERTYHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_background_exception" - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - PROPERTY .connected") -class TestIoTHubDeviceClientPROPERTYConnected( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientPROPERTYConnectedTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - OCCURRENCE: Connect") -class TestIoTHubDeviceClientOCCURRENCEConnect( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientOCCURRENCEConnectTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - OCCURRENCE: Disconnect") -class TestIoTHubDeviceClientOCCURRENCEDisconnect( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientOCCURRENCEDisconnectTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - OCCURRENCE: New Sastoken Required") -class TestIoTHubDeviceClientOCCURRENCENewSastokenRequired( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientOCCURRENCENewSastokenRequired -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - OCCURRENCE: Background Exception") -class TestIoTHubDeviceClientOCCURRENCEBackgroundException( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientOCCURRENCEBackgroundException -): - pass - - -################ -# MODULE TESTS # -################ -class IoTHubModuleClientTestsConfig(object): - @pytest.fixture - def client_class(self): - return IoTHubModuleClient - - @asyncio_fixture - async def client(self, mqtt_pipeline, http_pipeline): - """This client automatically resolves callbacks sent to the pipeline. - It should be used for the majority of tests. 
- """ - client = IoTHubModuleClient(mqtt_pipeline, http_pipeline) - yield client - await client.shutdown() - - @pytest.fixture - def connection_string(self, module_connection_string): - """This fixture is parametrized to prove all valid device connection strings. - See client_fixtures.py - """ - return module_connection_string - - @pytest.fixture - def sas_token_string(self, module_sas_token_string): - return module_sas_token_string - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - Instantiation") -class TestIoTHubModuleClientInstantiation( - IoTHubModuleClientTestsConfig, SharedIoTHubClientInstantiationTests -): - @pytest.mark.it("Sets on_input_message_received handler in the MQTTPipeline") - async def test_sets_on_input_message_received_handler_in_pipeline( - self, client_class, mqtt_pipeline, http_pipeline - ): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline.on_input_message_received is not None - assert ( - client._mqtt_pipeline.on_input_message_received - == client._inbox_manager.route_input_message - ) - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .create_from_connection_string()") -class TestIoTHubModuleClientCreateFromConnectionString( - IoTHubModuleClientTestsConfig, SharedIoTHubClientCreateFromConnectionStringTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .create_from_sastoken()") -class TestIoTHubModuleClientCreateFromSastoken( - IoTHubModuleClientTestsConfig, SharedIoTHubModuleClientCreateFromSastokenTests -): - pass - - -@pytest.mark.describe( - "IoTHubModuleClient (Asynchronous) - .create_from_edge_environment() -- Edge Container Environment" -) -class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnv( - IoTHubModuleClientTestsConfig, - SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnvTests, -): - pass - - -@pytest.mark.describe( - "IoTHubModuleClient (Asynchronous) - .create_from_edge_environment() -- Edge Local Debug 
Environment" -) -class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv( - IoTHubModuleClientTestsConfig, - SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnvTests, -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .create_from_x509_certificate()") -class TestIoTHubModuleClientCreateFromX509Certificate( - IoTHubModuleClientTestsConfig, SharedIoTHubModuleClientCreateFromX509CertificateTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .shutdown()") -class TestIoTHubModuleClientShutdown(IoTHubModuleClientTestsConfig, SharedClientShutdownTests): - @pytest.fixture - def client(self, mqtt_pipeline, http_pipeline): - """Override the client so that it doesn't shutdown during cleanup. - Shutdown can only be done once, and it is under test, so shutting down again during cleanup - will fail. - """ - return IoTHubModuleClient(mqtt_pipeline, http_pipeline) - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .update_sastoken()") -class TestIoTHubModuleClientUpdateSasToken( - IoTHubModuleClientTestsConfig, SharedClientUpdateSasTokenTests -): - @pytest.fixture - def module_id(self, sas_token_string): - # NOTE: This is kind of unconventional, but this is the easiest way to extract the - # module id from a sastoken string - sastoken = st.NonRenewableSasToken(sas_token_string) - token_uri_pieces = sastoken.resource_uri.split("/") - module_id = token_uri_pieces[4] - return module_id - - @pytest.fixture - def uri(self, hostname, device_id, module_id): - return "{hostname}/devices/{device_id}/modules/{module_id}".format( - hostname=hostname, device_id=device_id, module_id=module_id - ) - - @pytest.fixture( - params=["Nonmatching Device ID", "Nonmatching Module ID", "Nonmatching Hostname"] - ) - def nonmatching_uri(self, request, device_id, module_id, hostname): - # NOTE: It would be preferable to have this as a parametrization on a test rather than a - # fixture, however, we need to use the 
device_id and hostname fixtures in order to ensure - # tests don't break when other fixtures change, and you can't include fixtures in a - # parametrization, so this also has to be a fixture - uri_format = "{hostname}/devices/{device_id}/modules/{module_id}" - if request.param == "Nonmatching Device ID": - return uri_format.format( - hostname=hostname, device_id="nonmatching_device", module_id=module_id - ) - elif request.param == "Nonmatching Module ID": - return uri_format.format( - hostname=hostname, device_id=device_id, module_id="nonmatching_module" - ) - else: - return uri_format.format( - hostname="nonmatching_hostname", device_id=device_id, module_id=module_id - ) - - @pytest.fixture( - params=[ - "Too short", - "Too long", - "Incorrectly formatted device notation", - "Incorrectly formatted module notation", - "Device URI", - ] - ) - def invalid_uri(self, request, device_id, module_id, hostname): - # NOTE: As in the nonmatching_uri fixture above, this is a workaround for parametrization - # that allows the usage of other fixtures in the parametrized value. Weird pattern, but - # necessary to ensure stability of the tests over time. - if request.param == "Too short": - # Doesn't have module ID - return "{}/devices/{}/modules".format(hostname, device_id) - elif request.param == "Too long": - # Extraneous value at the end - return "{}/devices/{}/modules/{}/somethingElse".format(hostname, device_id, module_id) - elif request.param == "Incorrectly formatted device notation": - # Doesn't have '/devices/' - return "{}/not-devices/{}/modules/{}".format(hostname, device_id, module_id) - elif request.param == "Incorrectly formatted module notation": - # Doesn't have '/modules/' - return "{}/devices/{}/not-modules/{}".format(hostname, device_id, module_id) - else: - # Valid... for a Device... 
but this is a Module - return "{}/devices/{}/".format(hostname, device_id) - - @pytest.fixture - def sas_config(self, sas_token_string): - """PipelineConfig set up as if using user-provided, non-renewable SAS auth""" - sastoken = st.NonRenewableSasToken(sas_token_string) - token_uri_pieces = sastoken.resource_uri.split("/") - hostname = token_uri_pieces[0] - device_id = token_uri_pieces[2] - module_id = token_uri_pieces[4] - sas_config = IoTHubPipelineConfig( - hostname=hostname, device_id=device_id, module_id=module_id, sastoken=sastoken - ) - return sas_config - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .connect()") -class TestIoTHubModuleClientConnect(IoTHubModuleClientTestsConfig, SharedClientConnectTests): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .disconnect()") -class TestIoTHubModuleClientDisconnect(IoTHubModuleClientTestsConfig, SharedClientDisconnectTests): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .send_message()") -class TestIoTHubNModuleClientSendD2CMessage( - IoTHubModuleClientTestsConfig, SharedClientSendD2CMessageTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .send_message_to_output()") -class TestIoTHubModuleClientSendToOutput(IoTHubModuleClientTestsConfig): - @pytest.mark.it("Begins a 'send_output_message' pipeline operation") - async def test_calls_pipeline_send_message_to_output(self, client, mqtt_pipeline, message): - output_name = "some_output" - await client.send_message_to_output(message, output_name) - assert mqtt_pipeline.send_output_message.call_count == 1 - assert mqtt_pipeline.send_output_message.call_args[0][0] is message - assert message.output_name == output_name - - @pytest.mark.it( - "Waits for the completion of the 'send_output_message' pipeline operation before returning" - ) - async def test_waits_for_pipeline_op_completion(self, mocker, client, mqtt_pipeline, message): - cb_mock = mocker.patch.object(async_adapter, 
"AwaitableCallback").return_value - cb_mock.completion.return_value = await create_completed_future(None) - - output_name = "some_output" - await client.send_message_to_output(message, output_name) - - # Assert callback is sent to pipeline - assert mqtt_pipeline.send_output_message.call_args[1]["callback"] is cb_mock - # Assert callback completion is waited upon - assert cb_mock.completion.call_count == 1 - - @pytest.mark.it( - "Raises a client error if the `send_output_message` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ConnectionDroppedError, - client_exceptions.ConnectionDroppedError, - id="ConnectionDroppedError->ConnectionDroppedError", - ), - pytest.param( - pipeline_exceptions.ConnectionFailedError, - client_exceptions.ConnectionFailedError, - id="ConnectionFailedError->ConnectionFailedError", - ), - pytest.param( - pipeline_exceptions.NoConnectionError, - client_exceptions.NoConnectionError, - id="NoConnectionError->NoConnectionError", - ), - pytest.param( - pipeline_exceptions.UnauthorizedError, - client_exceptions.CredentialError, - id="UnauthorizedError->CredentialError", - ), - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param( - pipeline_exceptions.OperationTimeout, - client_exceptions.OperationTimeout, - id="OperationTimeout -> OperationTimeout", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - async def test_raises_error_on_pipeline_op_error( - self, mocker, client, mqtt_pipeline, message, pipeline_error, client_error - ): - my_pipeline_error = pipeline_error() - - def fail_send_output_message(message, callback): - 
callback(error=my_pipeline_error) - - mqtt_pipeline.send_output_message = mocker.MagicMock(side_effect=fail_send_output_message) - with pytest.raises(client_error) as e_info: - output_name = "some_output" - await client.send_message_to_output(message, output_name) - assert e_info.value.__cause__ is my_pipeline_error - assert mqtt_pipeline.send_output_message.call_count == 1 - - @pytest.mark.it( - "Wraps 'message' input parameter in Message object if it is not a Message object" - ) - @pytest.mark.parametrize( - "message_input", - [ - pytest.param("message", id="String input"), - pytest.param(222, id="Integer input"), - pytest.param(object(), id="Object input"), - pytest.param(None, id="None input"), - pytest.param([1, "str"], id="List input"), - pytest.param({"a": 2}, id="Dictionary input"), - ], - ) - async def test_send_message_to_output_calls_pipeline_wraps_data_in_message( - self, client, mqtt_pipeline, message_input - ): - output_name = "some_output" - await client.send_message_to_output(message_input, output_name) - assert mqtt_pipeline.send_output_message.call_count == 1 - sent_message = mqtt_pipeline.send_output_message.call_args[0][0] - assert isinstance(sent_message, Message) - assert sent_message.data == message_input - - @pytest.mark.it("Raises error when message data size is greater than 256 KB") - async def test_raises_error_when_message_to_output_data_greater_than_256( - self, client, mqtt_pipeline - ): - output_name = "some_output" - data_input = "serpensortia" * 256000 - message = Message(data_input) - with pytest.raises(ValueError) as e_info: - await client.send_message_to_output(message, output_name) - assert "256 KB" in e_info.value.args[0] - assert mqtt_pipeline.send_output_message.call_count == 0 - - @pytest.mark.it("Raises error when message size is greater than 256 KB") - async def test_raises_error_when_message_to_output_size_greater_than_256( - self, client, mqtt_pipeline - ): - output_name = "some_output" - data_input = "serpensortia" - 
message = Message(data_input) - message.custom_properties["spell"] = data_input * 256000 - with pytest.raises(ValueError) as e_info: - await client.send_message_to_output(message, output_name) - assert "256 KB" in e_info.value.args[0] - assert mqtt_pipeline.send_output_message.call_count == 0 - - @pytest.mark.it("Does not raises error when message data size is equal to 256 KB") - async def test_raises_error_when_message_to_output_data_equal_to_256( - self, client, mqtt_pipeline - ): - output_name = "some_output" - data_input = "a" * 262095 - message = Message(data_input) - # This check was put as message class may undergo the default content type encoding change - # and the above calculation will change. - if message.get_size() != device_constant.TELEMETRY_MESSAGE_SIZE_LIMIT: - assert False - - await client.send_message_to_output(message, output_name) - - assert mqtt_pipeline.send_output_message.call_count == 1 - sent_message = mqtt_pipeline.send_output_message.call_args[0][0] - assert isinstance(sent_message, Message) - assert sent_message.data == data_input - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .receive_message_on_input()") -class TestIoTHubModuleClientReceiveInputMessage(IoTHubModuleClientTestsConfig): - @pytest.mark.it("Implicitly enables input messaging feature if not already enabled") - async def test_enables_input_messaging_only_if_not_already_enabled( - self, mocker, client, mqtt_pipeline - ): - # patch this receive_message_on_input won't block - mocker.patch.object( - AsyncClientInbox, "get", return_value=(await create_completed_future(None)) - ) - input_name = "some_input" - - # Verify Input Messaging enabled if not enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = ( - False # Input Messages will appear disabled - ) - await client.receive_message_on_input(input_name) - assert mqtt_pipeline.enable_feature.call_count == 1 - assert mqtt_pipeline.enable_feature.call_args[0][0] == pipeline_constant.INPUT_MSG - - 
mqtt_pipeline.enable_feature.reset_mock() - - # Verify Input Messaging not enabled if already enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = ( - True # Input Messages will appear enabled - ) - await client.receive_message_on_input(input_name) - assert mqtt_pipeline.enable_feature.call_count == 0 - - @pytest.mark.it("Returns a message from the input inbox, if available") - async def test_returns_message_from_input_inbox(self, mocker, client, message): - inbox_mock = mocker.MagicMock(autospec=AsyncClientInbox) - inbox_mock.get.return_value = await create_completed_future(message) - manager_get_inbox_mock = mocker.patch.object( - client._inbox_manager, "get_input_message_inbox", return_value=inbox_mock - ) - - input_name = "some_input" - received_message = await client.receive_message_on_input(input_name) - assert manager_get_inbox_mock.call_count == 1 - assert manager_get_inbox_mock.call_args == mocker.call(input_name) - assert inbox_mock.get.call_count == 1 - assert received_message is message - - @pytest.mark.it("Locks the client to API Receive Mode if the receive mode has not yet been set") - async def test_receive_mode_not_set(self, mocker, client): - # patch this so API won't block - mocker.patch.object( - AsyncClientInbox, "get", return_value=(await create_completed_future(None)) - ) - - assert client._receive_type is RECEIVE_TYPE_NONE_SET - await client.receive_message_on_input("some_input") - assert client._receive_type is RECEIVE_TYPE_API - - @pytest.mark.it( - "Does not modify the client receive mode if it has already been set to API Receive Mode" - ) - async def test_receive_mode_set_api(self, mocker, client): - # patch this so API won't block - mocker.patch.object( - AsyncClientInbox, "get", return_value=(await create_completed_future(None)) - ) - - client._receive_type = RECEIVE_TYPE_API - await client.receive_message_on_input("some_input") - assert client._receive_type is RECEIVE_TYPE_API - - @pytest.mark.it( - "Raises a ClientError 
and does nothing else if the client receive mode has been set to Handler Receive Mode" - ) - async def test_receive_mode_set_handler(self, mocker, client, mqtt_pipeline): - # patch this so API won't block - inbox_get_mock = mocker.patch.object( - AsyncClientInbox, "get", return_value=(await create_completed_future(None)) - ) - # patch this so we can make sure feature enabled isn't modified - mqtt_pipeline.feature_enabled.__getitem__.return_value = False - - client._receive_type = RECEIVE_TYPE_HANDLER - # Error was raised - with pytest.raises(client_exceptions.ClientError): - await client.receive_message_on_input("some_input") - # Feature was not enabled - assert mqtt_pipeline.enable_feature.call_count == 0 - # Inbox get was not called - assert inbox_get_mock.call_count == 0 - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .receive_method_request()") -class TestIoTHubModuleClientReceiveMethodRequest( - IoTHubModuleClientTestsConfig, SharedClientReceiveMethodRequestTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .send_method_response()") -class TestIoTHubModuleClientSendMethodResponse( - IoTHubModuleClientTestsConfig, SharedClientSendMethodResponseTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .get_twin()") -class TestIoTHubModuleClientGetTwin(IoTHubModuleClientTestsConfig, SharedClientGetTwinTests): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .patch_twin_reported_properties()") -class TestIoTHubModuleClientPatchTwinReportedProperties( - IoTHubModuleClientTestsConfig, SharedClientPatchTwinReportedPropertiesTests -): - pass - - -@pytest.mark.describe( - "IoTHubModuleClient (Asynchronous) - .receive_twin_desired_properties_patch()" -) -class TestIoTHubModuleClientReceiveTwinDesiredPropertiesPatch( - IoTHubModuleClientTestsConfig, SharedClientReceiveTwinDesiredPropertiesPatchTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) 
-.invoke_method()") -class TestIoTHubModuleClientInvokeMethod(IoTHubModuleClientTestsConfig): - @pytest.mark.it("Begins a 'invoke_method' HTTPPipeline operation where the target is a device") - async def test_calls_pipeline_invoke_method_for_device(self, mocker, client, http_pipeline): - method_params = {"methodName": "__fake_method_name__"} - device_id = "__fake_device_id__" - await client.invoke_method(method_params, device_id) - assert http_pipeline.invoke_method.call_count == 1 - assert http_pipeline.invoke_method.call_args == mocker.call( - device_id, method_params, callback=mocker.ANY, module_id=None - ) - - @pytest.mark.it("Begins a 'invoke_method' HTTPPipeline operation where the target is a module") - async def test_calls_pipeline_invoke_method_for_module(self, mocker, client, http_pipeline): - method_params = {"methodName": "__fake_method_name__"} - device_id = "__fake_device_id__" - module_id = "__fake_module_id__" - await client.invoke_method(method_params, device_id, module_id=module_id) - assert http_pipeline.invoke_method.call_count == 1 - # assert http_pipeline.invoke_method.call_args[0][0] is device_id - # assert http_pipeline.invoke_method.call_args[0][1] is method_params - assert http_pipeline.invoke_method.call_args == mocker.call( - device_id, method_params, callback=mocker.ANY, module_id=module_id - ) - - @pytest.mark.it( - "Waits for the completion of the 'invoke_method' pipeline operation before returning" - ) - async def test_waits_for_pipeline_op_completion(self, mocker, client, http_pipeline): - method_params = {"methodName": "__fake_method_name__"} - device_id = "__fake_device_id__" - module_id = "__fake_module_id__" - cb_mock = mocker.patch.object(async_adapter, "AwaitableCallback").return_value - cb_mock.completion.return_value = await create_completed_future(None) - - await client.invoke_method(method_params, device_id, module_id=module_id) - - # Assert callback is sent to pipeline - assert 
http_pipeline.invoke_method.call_args[1]["callback"] is cb_mock - # Assert callback completion is waited upon - assert cb_mock.completion.call_count == 1 - - @pytest.mark.it( - "Raises a client error if the `invoke_method` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - async def test_raises_error_on_pipeline_op_error( - self, mocker, client, http_pipeline, pipeline_error, client_error - ): - method_params = {"methodName": "__fake_method_name__"} - device_id = "__fake_device_id__" - module_id = "__fake_module_id__" - my_pipeline_error = pipeline_error() - - def fail_invoke_method(method_params, device_id, callback, module_id=None): - return callback(error=my_pipeline_error) - - http_pipeline.invoke_method = mocker.MagicMock(side_effect=fail_invoke_method) - - with pytest.raises(client_error) as e_info: - await client.invoke_method(method_params, device_id, module_id=module_id) - - assert e_info.value.__cause__ is my_pipeline_error - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - PROPERTY .on_message_received") -class TestIoTHubModuleClientPROPERTYOnMessageReceivedHandler( - IoTHubModuleClientTestsConfig, SharedIoTHubClientPROPERTYReceiverHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_message_received" - - @pytest.fixture - def feature_name(self): - return pipeline_constant.INPUT_MSG - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - PROPERTY .on_method_request_received") -class TestIoTHubModuleClientPROPERTYOnMethodRequestReceivedHandler( - 
IoTHubModuleClientTestsConfig, SharedIoTHubClientPROPERTYReceiverHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_method_request_received" - - @pytest.fixture - def feature_name(self): - return pipeline_constant.METHODS - - -@pytest.mark.describe( - "IoTHubModuleClient (Asynchronous) - PROPERTY .on_twin_desired_properties_patch_received" -) -class TestIoTHubModuleClientPROPERTYOnTwinDesiredPropertiesPatchReceivedHandler( - IoTHubModuleClientTestsConfig, SharedIoTHubClientPROPERTYReceiverHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_twin_desired_properties_patch_received" - - @pytest.fixture - def feature_name(self): - return pipeline_constant.TWIN_PATCHES - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - PROPERTY .on_connection_state_change") -class TestIoTHubModuleClientPROPERTYOnConnectionStateChangeHandler( - IoTHubModuleClientTestsConfig, SharedIoTHubClientPROPERTYHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_connection_state_change" - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - PROPERTY .on_new_sastoken_required") -class TestIoTHubModuleClientPROPERTYOnNewSastokenRequiredHandler( - IoTHubModuleClientTestsConfig, SharedIoTHubClientPROPERTYHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_new_sastoken_required" - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - PROPERTY .on_background_exception") -class TestIoTHubModuleClientPROPERTYOnBackgroundExceptionHandler( - IoTHubModuleClientTestsConfig, SharedIoTHubClientPROPERTYHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_background_exception" - - -@pytest.mark.describe("IoTHubModule (Asynchronous) - PROPERTY .connected") -class TestIoTHubModuleClientPROPERTYConnected( - IoTHubModuleClientTestsConfig, SharedIoTHubClientPROPERTYConnectedTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - OCCURRENCE: Connect") -class 
TestIoTHubModuleClientOCCURRENCEConnect( - IoTHubModuleClientTestsConfig, SharedIoTHubClientOCCURRENCEConnectTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - OCCURRENCE: Disconnect") -class TestIoTHubModuleClientOCCURRENCEDisconnect( - IoTHubModuleClientTestsConfig, SharedIoTHubClientOCCURRENCEDisconnectTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - OCCURRENCE: New Sastoken Required") -class TestIoTHubModuleClientOCCURRENCENewSastokenRequired( - IoTHubModuleClientTestsConfig, SharedIoTHubClientOCCURRENCENewSastokenRequired -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - OCCURRENCE: Background Exception") -class TestIoTHubModuleClientOCCURRENCEBackgroundException( - IoTHubModuleClientTestsConfig, SharedIoTHubClientOCCURRENCEBackgroundException -): - pass diff --git a/tests/unit/iothub/aio/test_async_handler_manager.py b/tests/unit/iothub/aio/test_async_handler_manager.py deleted file mode 100644 index b3dc5c056..000000000 --- a/tests/unit/iothub/aio/test_async_handler_manager.py +++ /dev/null @@ -1,835 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import logging -import pytest -import asyncio -import threading -import concurrent.futures -from azure.iot.device.common import handle_exceptions -from azure.iot.device.iothub import client_event -from azure.iot.device.iothub.aio.async_handler_manager import AsyncHandlerManager -from azure.iot.device.iothub.sync_handler_manager import HandlerManagerException -from azure.iot.device.iothub.sync_handler_manager import MESSAGE, METHOD, TWIN_DP_PATCH -from azure.iot.device.iothub.inbox_manager import InboxManager -from azure.iot.device.iothub.aio.async_inbox import AsyncClientInbox - -pytestmark = pytest.mark.asyncio -logging.basicConfig(level=logging.DEBUG) - -# NOTE ON TEST IMPLEMENTATION: -# Despite having significant shared implementation between the sync and async handler managers, -# there are not shared tests. This is because while both have the same set of requirements and -# APIs, the internal implementation is different to an extent that it simply isn't really possible -# to test them to an appropriate degree of correctness with a shared set of tests. -# This means we must be very careful to always change both test modules when a change is made to -# shared behavior, or when shared features are added. - -# NOTE ON TIMING/DELAY -# Several tests in this module have sleeps/delays in their implementation due to needing to wait -# for things to happen in other threads. 
- - -all_internal_receiver_handlers = [MESSAGE, METHOD, TWIN_DP_PATCH] -all_internal_client_event_handlers = [ - "_on_connection_state_change", - "_on_new_sastoken_required", - "_on_background_exception", -] -all_internal_handlers = all_internal_receiver_handlers + all_internal_client_event_handlers -all_receiver_handlers = [s.lstrip("_") for s in all_internal_receiver_handlers] -all_client_event_handlers = [s.lstrip("_") for s in all_internal_client_event_handlers] -all_handlers = all_receiver_handlers + all_client_event_handlers - - -class ThreadsafeMock(object): - """This class provides (some) Mock functionality in a threadsafe manner, specifically, it - ensures that the 'call_count' attribute will be accurate when the mock is called from another - thread. - - It does not cover ALL mock functionality, but more features could be added to it as necessary - """ - - def __init__(self): - self.call_count = 0 - self.lock = threading.Lock() - - def __call__(self, *args, **kwargs): - with self.lock: - self.call_count += 1 - - -@pytest.fixture -def inbox_manager(): - return InboxManager(inbox_type=AsyncClientInbox) - - -# ---------------------- -# We have to do some unfortunate things here in order to manually mock out handlers, to test -# tha they have been called. We can't use MagicMocks because not only do they not work well -# with coroutines, but especially if we are trying to test the very fact that functions and -# coroutines both work in the HandlerManager, replacing those things (i.e. the things under test) -# with a Mock.... really doesn't help us achieve that goal. 
- - -@pytest.fixture -def handler_checker(): - class HandlerChecker(object): - def __init__(self): - self.handler_called = False - self.handler_call_count = 0 - self.handler_call_args = None - self.lock = threading.Lock() - - return HandlerChecker() - - -@pytest.fixture(params=["Handler function", "Handler coroutine"]) -def handler(request, handler_checker): - if request.param == "Handler function": - - def some_handler_fn(*args): - with handler_checker.lock: - handler_checker.handler_called = True - handler_checker.handler_call_count += 1 - handler_checker.handler_call_args = args - - return some_handler_fn - - else: - - async def some_handler_coro(*args): - with handler_checker.lock: - handler_checker.handler_called = True - handler_checker.handler_call_count += 1 - handler_checker.handler_call_args = args - - return some_handler_coro - - -# ---------------------- - - -@pytest.mark.describe("AsyncHandlerManager - Instantiation") -class TestInstantiation(object): - @pytest.mark.it("Initializes handler properties to None") - @pytest.mark.parametrize("handler_name", all_handlers) - def test_handlers(self, inbox_manager, handler_name): - hm = AsyncHandlerManager(inbox_manager) - assert getattr(hm, handler_name) is None - - @pytest.mark.it("Initializes receiver handler runner task references to None") - @pytest.mark.parametrize( - "handler_name", all_internal_receiver_handlers, ids=all_receiver_handlers - ) - def test_handler_runners(self, inbox_manager, handler_name): - hm = AsyncHandlerManager(inbox_manager) - assert hm._receiver_handler_runners[handler_name] is None - - @pytest.mark.it("Initializes client event handler runner task reference to None") - def test_client_event_handler_runner(self, inbox_manager): - hm = AsyncHandlerManager(inbox_manager) - assert hm._client_event_runner is None - - -@pytest.mark.describe("AsyncHandlerManager - .stop()") -class TestStop(object): - @pytest.fixture( - params=[ - "No handlers running", - "Some receiver handlers running", 
- "Some client event handlers running", - "Some receiver and some client event handlers running", - "All handlers running", - ] - ) - def handler_manager(self, mocker, request, inbox_manager): - hm = AsyncHandlerManager(inbox_manager) - if request.param == "Some receiver handlers running": - # Set an arbitrary receiver handler - hm.on_message_received = mocker.MagicMock() - elif request.param == "Some client event handlers running": - # Set an arbitrary client event handler - hm.on_connection_state_change = mocker.MagicMock() - elif request.param == "Some receiver and some client event handlers running": - # Set an arbitrary receiver and client event handler - hm.on_message_received = mocker.MagicMock() - hm.on_connection_state_change = mocker.MagicMock() - elif request.param == "All handlers running": - for handler_name in all_handlers: - setattr(hm, handler_name, mocker.MagicMock()) - yield hm - hm.stop() - - @pytest.mark.it("Stops all currently running handlers") - def test_stops_all_runners(self, handler_manager): - handler_manager.stop() - for handler_name in all_internal_receiver_handlers: - assert handler_manager._receiver_handler_runners[handler_name] is None - assert handler_manager._client_event_runner is None - - @pytest.mark.it( - "Stops only the currently running receiver handlers if the 'receiver_handlers_only' parameter is True" - ) - def test_stop_only_receiver_handlers(self, handler_manager): - if handler_manager._client_event_runner is not None: - client_event_handlers_running = True - else: - client_event_handlers_running = False - - handler_manager.stop(receiver_handlers_only=True) - - # All receiver handlers have stopped - for handler_name in all_internal_receiver_handlers: - assert handler_manager._receiver_handler_runners[handler_name] is None - # If the client event handlers were running, they are STILL running - if client_event_handlers_running: - assert handler_manager._client_event_runner is not None - - @pytest.mark.it("Completes all 
pending handler invocations before stopping the runner(s)") - async def test_completes_pending(self, mocker, inbox_manager): - hm = AsyncHandlerManager(inbox_manager) - - # NOTE: We use two handlers arbitrarily here to show this happens for all handler runners - mock_msg_handler = ThreadsafeMock() - mock_mth_handler = ThreadsafeMock() - msg_inbox = inbox_manager.get_unified_message_inbox() - mth_inbox = inbox_manager.get_method_request_inbox() - for _ in range(200): # sufficiently many items so can't complete quickly - msg_inbox.put(mocker.MagicMock()) - mth_inbox.put(mocker.MagicMock()) - - hm.on_message_received = mock_msg_handler - hm.on_method_request_received = mock_mth_handler - assert mock_msg_handler.call_count < 200 - assert mock_mth_handler.call_count < 200 - hm.stop() - await asyncio.sleep(0.1) - assert mock_msg_handler.call_count == 200 - assert mock_mth_handler.call_count == 200 - assert msg_inbox.empty() - assert mth_inbox.empty() - - -@pytest.mark.describe("AsyncHandlerManager - .ensure_running()") -class TestEnsureRunning(object): - @pytest.fixture( - params=[ - "All handlers set, all stopped", - "All handlers set, receivers stopped, client events running", - "All handlers set, all running", - "Some receiver and client event handlers set, all stopped", - "Some receiver and client event handlers set, receivers stopped, client events running", - "Some receiver and client event handlers set, all running", - "Some receiver handlers set, all stopped", - "Some receiver handlers set, all running", - "Some client event handlers set, all stopped", - "Some client event handlers set, all running", - "No handlers set", - ] - ) - def handler_manager(self, mocker, request, inbox_manager): - hm = AsyncHandlerManager(inbox_manager) - - if request.param == "All handlers set, all stopped": - for handler_name in all_handlers: - setattr(hm, handler_name, handler) - hm.stop() - elif request.param == "All handlers set, receivers stopped, client events running": - for 
handler_name in all_handlers: - setattr(hm, handler_name, handler) - hm.stop(receiver_handlers_only=True) - elif request.param == "All handlers set, all running": - for handler_name in all_handlers: - setattr(hm, handler_name, handler) - elif request.param == "Some receiver and client event handlers set, all stopped": - hm.on_message_received = handler - hm.on_method_request_received = handler - hm.on_connection_state_change = handler - hm.on_new_sastoken_required = handler - hm.stop() - elif ( - request.param - == "Some receiver and client event handlers set, receivers stopped, client events running" - ): - hm.on_message_received = handler - hm.on_method_request_received = handler - hm.on_connection_state_change = handler - hm.on_new_sastoken_required = handler - hm.stop(receiver_handlers_only=True) - elif request.param == "Some receiver and client event handlers set, all running": - hm.on_message_received = handler - hm.on_method_request_received = handler - hm.on_connection_state_change = handler - hm.on_new_sastoken_required = handler - elif request.param == "Some receiver handlers set, all stopped": - hm.on_message_received = handler - hm.on_method_request_received = handler - hm.stop() - elif request.param == "Some receiver handlers set, all running": - hm.on_message_received = handler - hm.on_method_request_received = handler - elif request.param == "Some client event handlers set, all stopped": - hm.on_connection_state_change = handler - hm.on_new_sastoken_required = handler - hm.stop() - elif request.param == "Some client event handlers set, all running": - hm.on_connection_state_change = handler - hm.on_new_sastoken_required = handler - - yield hm - hm.stop() - - @pytest.mark.it( - "Starts handler runners for any handler that is set, but does not have a handler runner running" - ) - def test_starts_runners_if_necessary(self, handler_manager): - handler_manager.ensure_running() - - # Check receiver handlers - for handler_name in all_receiver_handlers: - if 
getattr(handler_manager, handler_name) is not None: - # NOTE: this assumes the convention of internal names being the name of a handler - # prefixed with a "_". If this ever changes, you must change this test. - assert handler_manager._receiver_handler_runners["_" + handler_name] is not None - - # Check client event handlers - for handler_name in all_client_event_handlers: - if getattr(handler_manager, handler_name) is not None: - assert handler_manager._client_event_runner is not None - # don't need to check the rest of the handlers since they all share a runner - break - - -# ############## -# # PROPERTIES # -# ############## - - -class SharedHandlerPropertyTests(object): - @pytest.fixture - def handler_manager(self, inbox_manager): - hm = AsyncHandlerManager(inbox_manager) - yield hm - hm.stop() - - # NOTE: We use setattr() and getattr() in these tests so they're generic to all properties. - # This is functionally identical to doing explicit assignment to a property, it just - # doesn't read quite as well. 
- - @pytest.mark.it("Can be both read and written to") - async def test_read_write(self, handler_name, handler_manager, handler): - assert getattr(handler_manager, handler_name) is None - setattr(handler_manager, handler_name, handler) - assert getattr(handler_manager, handler_name) is handler - setattr(handler_manager, handler_name, None) - assert getattr(handler_manager, handler_name) is None - - -class SharedReceiverHandlerPropertyTests(SharedHandlerPropertyTests): - # NOTE: If there is ever any deviation in the convention of what the internal names of handlers - # are other than just a prefixed "_", we'll have to move this fixture to the child classes so - # it can be unique to each handler - @pytest.fixture - def handler_name_internal(self, handler_name): - return "_" + handler_name - - @pytest.mark.it( - "Creates and stores a Future for the corresponding handler runner when value is set to a function or coroutine handler" - ) - async def test_future_created( - self, handler_name, handler_name_internal, handler_manager, handler - ): - assert handler_manager._receiver_handler_runners[handler_name_internal] is None - setattr(handler_manager, handler_name, handler) - assert isinstance( - handler_manager._receiver_handler_runners[handler_name_internal], - concurrent.futures.Future, - ) - - @pytest.mark.it( - "Stops the corresponding handler runner and deletes any existing stored Future for it when the value is set back to None" - ) - async def test_future_removed( - self, handler_name, handler_name_internal, handler_manager, handler - ): - # Set handler - setattr(handler_manager, handler_name, handler) - # Future has been created and is active - fut = handler_manager._receiver_handler_runners[handler_name_internal] - assert isinstance(fut, concurrent.futures.Future) - assert not fut.done() - # Set the handler back to None - setattr(handler_manager, handler_name, None) - # Future has been completed, and the manager no longer has a reference to it - assert 
fut.done() - assert handler_manager._receiver_handler_runners[handler_name_internal] is None - - @pytest.mark.it( - "Does not delete, remove, or replace the Future for the corresponding handler runner when updated with a new function or coroutine value" - ) - async def test_future_unchanged_by_handler_update( - self, handler_name, handler_name_internal, handler_manager, handler - ): - # Set the handler - setattr(handler_manager, handler_name, handler) - # Future has been created and is active - future = handler_manager._receiver_handler_runners[handler_name_internal] - assert isinstance(future, concurrent.futures.Future) - assert not future.done() - - # Set new handler - def new_handler(arg): - pass - - setattr(handler_manager, handler_name, new_handler) - # Future has not completed, and is still maintained by the manager - assert handler_manager._receiver_handler_runners[handler_name_internal] is future - assert not future.done() - - @pytest.mark.it( - "Is invoked by the runner when the Inbox corresponding to the handler receives an object, passing that object to the handler" - ) - async def test_handler_invoked( - self, mocker, handler_name, handler_manager, handler, handler_checker, inbox - ): - # Set the handler - setattr(handler_manager, handler_name, handler) - # Handler has not been called - assert handler_checker.handler_called is False - assert handler_checker.handler_call_args is None - - # Add an item to the associated inbox, triggering the handler - mock_obj = mocker.MagicMock() - inbox.put(mock_obj) - await asyncio.sleep(0.1) - - # Handler has been called with the item from the inbox - assert handler_checker.handler_called is True - assert handler_checker.handler_call_args == (mock_obj,) - - @pytest.mark.it( - "Is invoked by the runner every time the Inbox corresponding to the handler receives an object" - ) - async def test_handler_invoked_multiple( - self, mocker, handler_name, handler_manager, handler, handler_checker, inbox - ): - # Set the handler 
- setattr(handler_manager, handler_name, handler) - # Handler has not been called - assert handler_checker.handler_call_count == 0 - - # Add 5 items to the associated inbox, triggering the handler - for _ in range(5): - inbox.put(mocker.MagicMock()) - await asyncio.sleep(0.1) - - # Handler has been called 5 times - assert handler_checker.handler_call_count == 5 - - @pytest.mark.it( - "Is invoked for every item already in the corresponding Inbox at the moment of handler removal" - ) - async def test_handler_resolve_pending_items_before_handler_removal( - self, mocker, handler_name, handler_manager, handler, handler_checker, inbox - ): - assert inbox.empty() - # Queue up a bunch of items in the inbox - for _ in range(100): - inbox.put(mocker.MagicMock()) - # The handler has not yet been called - assert handler_checker.handler_call_count == 0 - # Items are still in the inbox - assert not inbox.empty() - # Set the handler - setattr(handler_manager, handler_name, handler) - # The handler has not yet been called for everything that was in the inbox - # NOTE: I'd really like to show that the handler call count is also > 0 here, but - # it's pretty difficult to make the timing work - await asyncio.sleep(0.1) - handler_checker.handler_call_count < 100 - - # Immediately remove the handler - setattr(handler_manager, handler_name, None) - # Wait to give a chance for the handler runner to finish calling everything - await asyncio.sleep(0.1) - # Despite removal, handler has been called for everything that was in the inbox at the - # time of the removal - assert handler_checker.handler_call_count == 100 - assert inbox.empty() - - # Add some more items - for _ in range(100): - inbox.put(mocker.MagicMock()) - # Wait to give a chance for the handler to be called (it won't) - await asyncio.sleep(0.1) - # Despite more items added to inbox, no further handler calls have been made beyond the - # initial calls that were made when the original items were added - assert 
handler_checker.handler_call_count == 100 - - @pytest.mark.it( - "Sends a HandlerManagerException to the background exception handler if any exception is raised during its invocation" - ) - async def test_exception_in_handler( - self, mocker, handler_name, handler_manager, inbox, arbitrary_exception - ): - # NOTE: this test tests both coroutines and functions without the need for parametrization - background_exc_spy = mocker.spy(handle_exceptions, "handle_background_exception") - - def function_handler(arg): - raise arbitrary_exception - - async def coro_handler(arg): - raise arbitrary_exception - - # Set function handler - setattr(handler_manager, handler_name, function_handler) - # Background exception handler has not been called - assert background_exc_spy.call_count == 0 - # Add an item to corresponding inbox, triggering the handler - inbox.put(mocker.MagicMock()) - await asyncio.sleep(0.1) - # Background exception handler was called - assert background_exc_spy.call_count == 1 - e = background_exc_spy.call_args[0][0] - assert isinstance(e, HandlerManagerException) - assert e.__cause__ is arbitrary_exception - - # Clear the spy - background_exc_spy.reset_mock() - - # Set coroutine handler - setattr(handler_manager, handler_name, coro_handler) - # Background exception handler has not been called - assert background_exc_spy.call_count == 0 - # Add an item to corresponding inbox, triggering the handler - inbox.put(mocker.MagicMock()) - await asyncio.sleep(0.1) - # Background exception handler was called - assert background_exc_spy.call_count == 1 - e = background_exc_spy.call_args[0][0] - assert isinstance(e, HandlerManagerException) - assert e.__cause__ is arbitrary_exception - - @pytest.mark.it( - "Can be updated with a new value that the corresponding handler runner will immediately begin using for handler invocations instead" - ) - async def test_handler_update_handler(self, mocker, handler_name, handler_manager, inbox): - # NOTE: this test tests both 
coroutines and functions without the need for parametrization - mock_handler = mocker.MagicMock() - - async def handler2(arg): - # Invoking handler2 replaces the set handler with a mock - setattr(handler_manager, handler_name, mock_handler) - - def handler1(arg): - # Invoking handler1 replaces the set handler with handler2 - setattr(handler_manager, handler_name, handler2) - - setattr(handler_manager, handler_name, handler1) - - inbox.put(mocker.MagicMock()) - await asyncio.sleep(0.1) - # The set handler (handler1) has been replaced with a new handler (handler2) - assert getattr(handler_manager, handler_name) is not handler1 - assert getattr(handler_manager, handler_name) is handler2 - # Add a new item to the inbox - inbox.put(mocker.MagicMock()) - await asyncio.sleep(0.1) - # The set handler (handler2) has now been replaced by a mock handler - assert getattr(handler_manager, handler_name) is not handler2 - assert getattr(handler_manager, handler_name) is mock_handler - # Add a new item to the inbox - inbox.put(mocker.MagicMock()) - await asyncio.sleep(0.1) - # The mock was now called - assert getattr(handler_manager, handler_name).call_count == 1 - - -class SharedClientEventHandlerPropertyTests(SharedHandlerPropertyTests): - @pytest.fixture - def inbox(self, inbox_manager): - return inbox_manager.get_client_event_inbox() - - @pytest.mark.it( - "Creates and stores a Future for the Client Event handler runner when value is set to a function or coroutine if the Client Event handler runner does not already exist" - ) - async def test_no_client_event_runner(self, handler_name, handler_manager, handler): - assert handler_manager._client_event_runner is None - setattr(handler_manager, handler_name, handler) - t = handler_manager._client_event_runner - assert isinstance(t, concurrent.futures.Future) - - @pytest.mark.it( - "Does not modify the Client Event handler runner future when value is set to a function or coroutine if the Client Event handler runner already exists" 
- ) - async def test_client_event_runner_already_exists(self, handler_name, handler_manager, handler): - # Add a fake client event runner future - fake_runner_future = concurrent.futures.Future() - handler_manager._client_event_runner = fake_runner_future - # Set handler - setattr(handler_manager, handler_name, handler) - # Fake future was not changed - assert handler_manager._client_event_runner is fake_runner_future - # Clean up the future so tests don't hang - fake_runner_future.cancel() - handler_manager._client_event_runner = None - - @pytest.mark.it( - "Does not delete, remove or replace the Future for the Client Event handler runner when value is set back to None" - ) - async def test_handler_removed(self, handler_name, handler_manager, handler): - # Set handler - setattr(handler_manager, handler_name, handler) - # Future as been created and is active - future = handler_manager._client_event_runner - assert isinstance(future, concurrent.futures.Future) - assert not future.done() - # Set the handler back to None - setattr(handler_manager, handler_name, None) - # Future is still maintained on the manager and active - assert handler_manager._client_event_runner is future - assert not future.done() - - @pytest.mark.it( - "Does not delete, remove or replace the Future for the Client Event handler runner when updated with a new function or coroutine value" - ) - async def test_handler_update(self, handler_name, handler_manager, handler): - # Set handler - setattr(handler_manager, handler_name, handler) - # Future as been created and is active - future = handler_manager._client_event_runner - assert isinstance(future, concurrent.futures.Future) - assert not future.done() - - # Set new handler - def new_handler(arg): - pass - - setattr(handler_manager, handler_name, new_handler) - - # Future is still maintained on the manager and active - assert handler_manager._client_event_runner is future - assert not future.done() - - @pytest.mark.it( - "Is invoked by the runner 
only when the Client Event Inbox receives a matching event, passing any arguments to the handler" - ) - async def test_handler_invoked( - self, mocker, handler_name, handler_manager, handler_checker, handler, inbox, event - ): - # Set the handler - setattr(handler_manager, handler_name, handler) - # Handler has not been called - assert handler_checker.handler_called is False - - # Add the event to the client inbox - inbox.put(event) - await asyncio.sleep(0.1) - - # Handler has been called with the arguments from the event - assert handler_checker.handler_call_count == 1 - assert handler_checker.handler_call_args == event.args_for_user - - # Add a non-matching event ot the client event inbox - non_matching_event = client_event.ClientEvent("NON_MATCHING_EVENT") - inbox.put(non_matching_event) - await asyncio.sleep(0.1) - - # Handler has not been called again - assert handler_checker.handler_call_count == 1 - - @pytest.mark.it( - "Is invoked by the runner every time the Client Event Inbox receives a matching Client Event" - ) - async def test_handler_invoked_multiple( - self, handler_name, handler_manager, handler, handler_checker, inbox, event - ): - # Set the handler - setattr(handler_manager, handler_name, handler) - # Handler has not been called - assert handler_checker.handler_call_count == 0 - - # Add 5 items to the corresponding inbox, triggering the handler - for _ in range(5): - inbox.put(event) - await asyncio.sleep(0.1) - - # Handler has been called 5 times - assert handler_checker.handler_call_count == 5 - - @pytest.mark.it( - "Sends a HandlerManagerException to the background exception handler if any exception is raised during its invocation" - ) - async def test_exception_in_handler( - self, mocker, handler_name, handler_manager, inbox, event, arbitrary_exception - ): - # NOTE: this test tests both coroutines and functions without the need for parametrization - background_exc_spy = mocker.spy(handle_exceptions, "handle_background_exception") - - def 
function_handler(*args): - raise arbitrary_exception - - async def coro_handler(*args): - raise arbitrary_exception - - # Set function handler - setattr(handler_manager, handler_name, function_handler) - # Background exception handler has not been called - assert background_exc_spy.call_count == 0 - # Add an item to corresponding inbox, triggering the handler - inbox.put(event) - await asyncio.sleep(0.1) - # Background exception handler was called - assert background_exc_spy.call_count == 1 - e = background_exc_spy.call_args[0][0] - assert isinstance(e, HandlerManagerException) - assert e.__cause__ is arbitrary_exception - - # Clear the spy - background_exc_spy.reset_mock() - - # Set coroutine handler - setattr(handler_manager, handler_name, coro_handler) - # Background exception handler has not been called - assert background_exc_spy.call_count == 0 - # Add an item to corresponding inbox, triggering the handler - inbox.put(event) - await asyncio.sleep(0.1) - # Background exception handler was called - assert background_exc_spy.call_count == 1 - e = background_exc_spy.call_args[0][0] - assert isinstance(e, HandlerManagerException) - assert e.__cause__ is arbitrary_exception - - @pytest.mark.it( - "Can be updated with a new value that the corresponding handler runner will immediately begin using for handler invocations instead" - ) - async def test_handler_update_handler( - self, mocker, handler_name, handler_manager, inbox, event - ): - # NOTE: this test tests both coroutines and functions without the need for parametrization - mock_handler = mocker.MagicMock() - - async def handler2(*args): - # Invoking handler2 replaces the set handler with a mock - setattr(handler_manager, handler_name, mock_handler) - - def handler1(*args): - # Invoking handler1 replaces the set handler with handler2 - setattr(handler_manager, handler_name, handler2) - - setattr(handler_manager, handler_name, handler1) - - inbox.put(event) - await asyncio.sleep(0.1) - # The set handler 
(handler1) has been replaced with a new handler (handler2) - assert getattr(handler_manager, handler_name) is not handler1 - assert getattr(handler_manager, handler_name) is handler2 - # Add a new item to the inbox - inbox.put(event) - await asyncio.sleep(0.1) - # The set handler (handler2) has now been replaced by a mock handler - assert getattr(handler_manager, handler_name) is not handler2 - assert getattr(handler_manager, handler_name) is mock_handler - # Add a new item to the inbox - inbox.put(event) - await asyncio.sleep(0.1) - # The mock was now called - assert getattr(handler_manager, handler_name).call_count == 1 - - -@pytest.mark.describe("AsyncHandlerManager - PROPERTY: .on_message_received") -class TestAsyncHandlerManagerPropertyOnMessageReceived(SharedReceiverHandlerPropertyTests): - @pytest.fixture - def handler_name(self): - return "on_message_received" - - @pytest.fixture - def inbox(self, inbox_manager): - return inbox_manager.get_unified_message_inbox() - - -@pytest.mark.describe("AsyncHandlerManager - PROPERTY: .on_method_request_received") -class TestAsyncHandlerManagerPropertyOnMethodRequestReceived(SharedReceiverHandlerPropertyTests): - @pytest.fixture - def handler_name(self): - return "on_method_request_received" - - @pytest.fixture - def inbox(self, inbox_manager): - return inbox_manager.get_method_request_inbox() - - -@pytest.mark.describe("AsyncHandlerManager - PROPERTY: .on_twin_desired_properties_patch_received") -class TestAsyncHandlerManagerPropertyOnTwinDesiredPropertiesPatchReceived( - SharedReceiverHandlerPropertyTests -): - @pytest.fixture - def handler_name(self): - return "on_twin_desired_properties_patch_received" - - @pytest.fixture - def inbox(self, inbox_manager): - return inbox_manager.get_twin_patch_inbox() - - -@pytest.mark.describe("AsyncHandlerManager - PROPERTY: .on_connection_state_change") -class TestAsyncHandlerManagerPropertyOnConnectionStateChange(SharedClientEventHandlerPropertyTests): - @pytest.fixture - def 
handler_name(self): - return "on_connection_state_change" - - @pytest.fixture - def event(self): - return client_event.ClientEvent(client_event.CONNECTION_STATE_CHANGE) - - -@pytest.mark.describe("AsyncHandlerManager - PROPERTY: .on_new_sastoken_required") -class TestAsyncHandlerManagerPropertyOnNewSastokenRequired(SharedClientEventHandlerPropertyTests): - @pytest.fixture - def handler_name(self): - return "on_new_sastoken_required" - - @pytest.fixture - def event(self): - return client_event.ClientEvent(client_event.NEW_SASTOKEN_REQUIRED) - - -@pytest.mark.describe("AsyncHandlerManager - PROPERTY: .on_background_exception") -class TestAsyncHandlerManagerPropertyOnBackgroundException(SharedClientEventHandlerPropertyTests): - @pytest.fixture - def handler_name(self): - return "on_background_exception" - - @pytest.fixture - def event(self, arbitrary_exception): - return client_event.ClientEvent(client_event.BACKGROUND_EXCEPTION, arbitrary_exception) - - -@pytest.mark.describe("AsyncHandlerManager - PROPERTY: .handling_client_events") -class TestAsyncHandlerManagerPropertyHandlingClientEvents(object): - @pytest.fixture - def handler_manager(self, inbox_manager): - hm = AsyncHandlerManager(inbox_manager) - yield hm - hm.stop() - - @pytest.mark.it("Is True if the Client Event Handler Runner is running") - async def test_client_event_runner_running(self, handler_manager): - # Add a fake client event runner thread - fake_runner_future = concurrent.futures.Future() - handler_manager._client_event_runner = fake_runner_future - - assert handler_manager.handling_client_events is True - - # Clean up the future so tests don't hang - fake_runner_future.cancel() - handler_manager._client_event_runner = None - - @pytest.mark.it("Is False if the Client Event Handler Runner is not running") - async def test_client_event_runner_not_running(self, handler_manager): - assert handler_manager._client_event_runner is None - assert handler_manager.handling_client_events is False diff --git 
a/tests/unit/iothub/aio/test_async_inbox.py b/tests/unit/iothub/aio/test_async_inbox.py deleted file mode 100644 index b98055fdc..000000000 --- a/tests/unit/iothub/aio/test_async_inbox.py +++ /dev/null @@ -1,125 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest -import asyncio -import logging -from azure.iot.device.iothub.aio.async_inbox import AsyncClientInbox - -logging.basicConfig(level=logging.DEBUG) - -# Note that some small delays need to be added at the end of async tests due to -# RuntimeWarnings being thrown by the test ending before janus can correctly -# resolve its Futures. This may be a bug in janus. - - -@pytest.mark.describe("AsyncClientInbox") -class TestAsyncClientInbox(object): - @pytest.mark.it("Instantiates empty") - def test_instantiates_empty(self): - inbox = AsyncClientInbox() - assert inbox.empty() - - @pytest.mark.it("Can be checked regarding whether or not it contains an item") - def test_check_item_is_in_inbox(self, mocker): - inbox = AsyncClientInbox() - assert inbox.empty() - item = mocker.MagicMock() - assert item not in inbox - inbox.put(item) - assert item in inbox - - @pytest.mark.it("Can be checked regarding whether or not it is empty") - @pytest.mark.asyncio - async def test_can_check_if_empty(self, mocker): - inbox = AsyncClientInbox() - assert inbox.empty() - inbox.put(mocker.MagicMock()) - assert not inbox.empty() - await inbox.get() - assert inbox.empty() - await asyncio.sleep(0.01) # Do this to prevent RuntimeWarning from janus - - @pytest.mark.it("Operates according to FIFO") - @pytest.mark.asyncio - async def test_operates_according_to_FIFO(self, mocker): - inbox = AsyncClientInbox() - item1 = mocker.MagicMock() - item2 = 
mocker.MagicMock() - item3 = mocker.MagicMock() - inbox.put(item1) - inbox.put(item2) - inbox.put(item3) - - assert await inbox.get() is item1 - assert await inbox.get() is item2 - assert await inbox.get() is item3 - - await asyncio.sleep(0.01) # Do this to prevent RuntimeWarning from janus - - -@pytest.mark.describe("AsyncClientInbox - .put()") -class TestAsyncClientInboxPut(object): - @pytest.mark.it("Adds the given item to the inbox") - def test_adds_item_to_inbox(self, mocker): - inbox = AsyncClientInbox() - assert inbox.empty() - item = mocker.MagicMock() - inbox.put(item) - assert not inbox.empty() - assert item in inbox - - -@pytest.mark.describe("AsyncClientInbox - .get()") -@pytest.mark.asyncio -class TestAsyncClientInboxGet(object): - @pytest.mark.it("Returns and removes the next item from the inbox, if there is one") - async def test_removes_item_from_inbox_if_already_there(self, mocker): - inbox = AsyncClientInbox() - assert inbox.empty() - item = mocker.MagicMock() - inbox.put(item) - assert not inbox.empty() - retrieved_item = await inbox.get() - assert retrieved_item is item - assert inbox.empty() - - await asyncio.sleep(0.01) # Do this to prevent RuntimeWarning from janus - - @pytest.mark.it( - "Blocks on an empty inbox until an item is available to remove and return, if using blocking mode" - ) - async def test_get_waits_for_item_to_be_added_if_inbox_empty(self, mocker): - inbox = AsyncClientInbox() - assert inbox.empty() - item = mocker.MagicMock() - - async def wait_for_item(): - retrieved_item = await inbox.get() - assert retrieved_item is item - - async def insert_item(): - await asyncio.sleep(1) # wait before adding item to ensure the above coroutine is first - inbox.put(item) - - await asyncio.gather(wait_for_item(), insert_item()) - - -@pytest.mark.describe("AsyncClientInbox - .clear()") -class TestAsyncClientInboxClear(object): - @pytest.mark.it("Clears all items from the inbox") - def test_can_clear_all_items(self, mocker): - inbox = 
AsyncClientInbox() - item1 = mocker.MagicMock() - item2 = mocker.MagicMock() - item3 = mocker.MagicMock() - inbox.put(item1) - inbox.put(item2) - inbox.put(item3) - assert not inbox.empty() - - inbox.clear() - assert inbox.empty() diff --git a/tests/unit/iothub/aio/test_loop_management.py b/tests/unit/iothub/aio/test_loop_management.py deleted file mode 100644 index 2d7af158a..000000000 --- a/tests/unit/iothub/aio/test_loop_management.py +++ /dev/null @@ -1,71 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest -import asyncio -import logging -from azure.iot.device.iothub.aio import loop_management - -logging.basicConfig(level=logging.DEBUG) - - -class SharedCustomLoopTests(object): - @pytest.fixture(autouse=True) - def setup_teardown(self): - # Run cleanup both before and after tests so that the changes made here do not - # impact other test modules when the tests are run as a complete suite - loop_management._cleanup() - yield - loop_management._cleanup() - - @pytest.mark.it("Returns a new event loop the first time it is called") - def test_new_loop(self, mocker, fn_under_test): - new_event_loop_mock = mocker.patch.object(asyncio, "new_event_loop") - loop = fn_under_test() - assert loop is new_event_loop_mock.return_value - - @pytest.mark.it("Begins running the new event loop in a daemon Thread") - def test_daemon_thread(self, mocker, fn_under_test): - mock_new_event_loop = mocker.patch("asyncio.new_event_loop") - mock_loop = mock_new_event_loop.return_value - mock_thread_init = mocker.patch("threading.Thread") - mock_thread = mock_thread_init.return_value - fn_under_test() - # Loop was created - assert mock_new_event_loop.call_count == 1 - # Loop is running on the new Thread - 
assert mock_thread_init.call_count == 1 - assert mock_thread_init.call_args == mocker.call(target=mock_loop.run_forever) - assert mock_thread.start.call_count == 1 - # Thread is a daemon - assert mock_thread.daemon is True - - @pytest.mark.it("Returns the same event loop each time it is called") - def test_same_loop(self, fn_under_test): - loop1 = fn_under_test() - loop2 = fn_under_test() - assert loop1 is loop2 - - -@pytest.mark.describe(".get_client_internal_loop()") -class TestGetClientInternalLoop(SharedCustomLoopTests): - @pytest.fixture - def fn_under_test(self): - return loop_management.get_client_internal_loop - - -@pytest.mark.describe(".get_client_handler_runner_loop()") -class TestGetClientHandlerRunnerLoop(SharedCustomLoopTests): - @pytest.fixture - def fn_under_test(self): - return loop_management.get_client_handler_runner_loop - - -@pytest.mark.describe(".get_client_handler_loop()") -class TestGetClientHandlerLoop(SharedCustomLoopTests): - @pytest.fixture - def fn_under_test(self): - return loop_management.get_client_handler_loop diff --git a/tests/unit/iothub/client_fixtures.py b/tests/unit/iothub/client_fixtures.py deleted file mode 100644 index 072e7a986..000000000 --- a/tests/unit/iothub/client_fixtures.py +++ /dev/null @@ -1,288 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import pytest -import time -import urllib -from azure.iot.device.iothub.models import Message, MethodResponse, MethodRequest -from azure.iot.device.common.models.x509 import X509 - - -"""---Constants---""" - -shared_access_key = "Zm9vYmFy" -hostname = "hostname.azure-net" -device_id = "MyDevice" -module_id = "MyModule" -gateway_hostname = "MyGatewayHostname" -signature = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=" -expiry = str(int(time.time()) + 3600) -fake_x509_cert_value = "fake_certificate_value" -fake_x509_cert_key = "fake_certificate_key" -fake_pass_phrase = "fake_pass_phrase" - - -"""----Shared model fixtures----""" - - -@pytest.fixture -def message(): - return Message("Message Payload") - - -@pytest.fixture -def method_response(): - return MethodResponse(request_id="1", status=200, payload={"key": "value"}) - - -@pytest.fixture -def method_request(): - return MethodRequest(request_id="1", name="some_method", payload={"key": "value"}) - - -"""----Shared Twin fixtures----""" - - -@pytest.fixture -def twin_patch_desired(): - return {"properties": {"desired": {"foo": 1}}} - - -@pytest.fixture -def twin_patch_reported(): - return {"properties": {"reported": {"bar": 2}}} - - -@pytest.fixture -def fake_twin(): - return {"fake_twin": True} - - -"""----Shared connection string fixtures----""" - -device_connection_string_format = ( - "HostName={hostname};DeviceId={device_id};SharedAccessKey={shared_access_key}" -) -device_connection_string_gateway_format = "HostName={hostname};DeviceId={device_id};SharedAccessKey={shared_access_key};GatewayHostName={gateway_hostname}" - -module_connection_string_format = "HostName={hostname};DeviceId={device_id};ModuleId={module_id};SharedAccessKey={shared_access_key}" -module_connection_string_gateway_format = "HostName={hostname};DeviceId={device_id};ModuleId={module_id};SharedAccessKey={shared_access_key};GatewayHostName={gateway_hostname}" - - 
-@pytest.fixture(params=["Device Connection String", "Device Connection String w/ Protocol Gateway"]) -def device_connection_string(request): - string_type = request.param - if string_type == "Device Connection String": - return device_connection_string_format.format( - hostname=hostname, device_id=device_id, shared_access_key=shared_access_key - ) - else: - return device_connection_string_gateway_format.format( - hostname=hostname, - device_id=device_id, - shared_access_key=shared_access_key, - gateway_hostname=gateway_hostname, - ) - - -@pytest.fixture(params=["Module Connection String", "Module Connection String w/ Protocol Gateway"]) -def module_connection_string(request): - string_type = request.param - if string_type == "Module Connection String": - return module_connection_string_format.format( - hostname=hostname, - device_id=device_id, - module_id=module_id, - shared_access_key=shared_access_key, - ) - else: - return module_connection_string_gateway_format.format( - hostname=hostname, - device_id=device_id, - module_id=module_id, - shared_access_key=shared_access_key, - gateway_hostname=gateway_hostname, - ) - - -"""----Shared SAS fixtures---""" - -sas_token_format = "SharedAccessSignature sr={uri}&sig={signature}&se={expiry}" -# when to use the skn format? 
-sas_token_skn_format = ( - "SharedAccessSignature sr={uri}&sig={signature}&se={expiry}&skn={shared_access_key_name}" -) - -# what about variant input with different ordered attributes -# SharedAccessSignature sig={signature-string}&se={expiry}&skn={policyName}&sr={URL-encoded-resourceURI} - - -@pytest.fixture() -def device_sas_token_string(): - uri = hostname + "/devices/" + device_id - return sas_token_format.format( - uri=urllib.parse.quote(uri, safe=""), - signature=urllib.parse.quote(signature, safe=""), - expiry=expiry, - ) - - -@pytest.fixture() -def module_sas_token_string(): - uri = hostname + "/devices/" + device_id + "/modules/" + module_id - return sas_token_format.format( - uri=urllib.parse.quote(uri, safe=""), - signature=urllib.parse.quote(signature, safe=""), - expiry=expiry, - ) - - -"""----Shared certificate fixtures----""" - - -@pytest.fixture() -def x509(): - return X509(fake_x509_cert_value, fake_x509_cert_key, fake_pass_phrase) - - -"""----Shared Edge Container configuration---""" - - -@pytest.fixture() -def edge_container_environment(): - return { - "IOTEDGE_MODULEID": "__FAKE_MODULE_ID__", - "IOTEDGE_DEVICEID": "__FAKE_DEVICE_ID__", - "IOTEDGE_IOTHUBHOSTNAME": "__FAKE_HOSTNAME__", - "IOTEDGE_GATEWAYHOSTNAME": "__FAKE_GATEWAY_HOSTNAME__", - "IOTEDGE_APIVERSION": "__FAKE_API_VERSION__", - "IOTEDGE_MODULEGENERATIONID": "__FAKE_MODULE_GENERATION_ID__", - "IOTEDGE_WORKLOADURI": "http://__FAKE_WORKLOAD_URI__/", - } - - -@pytest.fixture() -def edge_local_debug_environment(): - cs = module_connection_string_gateway_format.format( - hostname=hostname, - device_id=device_id, - module_id=module_id, - shared_access_key=shared_access_key, - gateway_hostname=gateway_hostname, - ) - return { - "EdgeHubConnectionString": cs, - "EdgeModuleCACertificateFile": "__FAKE_SERVER_VERIFICATION_CERTIFICATE__", - } - - -@pytest.fixture -def mock_edge_hsm(mocker): - mock_edge_hsm = mocker.patch("azure.iot.device.iothub.edge_hsm.IoTEdgeHsm") - 
mock_edge_hsm.return_value.sign.return_value = "8NJRMT83CcplGrAGaUVIUM/md5914KpWVNngSVoF9/M=" - mock_edge_hsm.return_value.get_certificate.return_value = ( - "__FAKE_SERVER_VERIFICATION_CERTIFICATE__" - ) - return mock_edge_hsm - - -"""----Shared mock pipeline fixture----""" - - -class FakeIoTHubPipeline: - def __init__(self): - self.feature_enabled = {} # This just has to be here for the spec - - def shutdown(self, callback): - callback() - - def connect(self, callback): - callback() - - def disconnect(self, callback): - callback() - - def reauthorize_connection(self, callback): - callback() - - def enable_feature(self, feature_name, callback): - callback() - - def disable_feature(self, feature_name, callback): - callback() - - def send_message(self, event, callback): - callback() - - def send_output_message(self, event, callback): - callback() - - def send_method_response(self, method_response, callback): - callback() - - def get_twin(self, callback): - callback(twin={}) - - def patch_twin_reported_properties(self, patch, callback): - callback() - - -class FakeHTTPPipeline: - def __init__(self): - pass - - def invoke_method(self, device_id, method_params, callback, module_id=None): - callback(invoke_method_response="__fake_method_response__") - - def get_storage_info_for_blob(self, blob_name, callback): - callback(storage_info="__fake_storage_info__") - - def notify_blob_upload_status( - self, correlation_id, is_success, status_code, status_description, callback - ): - callback() - - -@pytest.fixture -def mqtt_pipeline(mocker): - """This fixture will automatically handle callbacks and should be - used in the majority of tests. 
- """ - return mocker.MagicMock(wraps=FakeIoTHubPipeline()) - - -@pytest.fixture -def mqtt_pipeline_manual_cb(mocker): - """This fixture is for use in tests where manual triggering of a - callback is required - """ - return mocker.MagicMock() - - -@pytest.fixture -def http_pipeline(mocker): - """This fixture will automatically handle callbacks and should be - used in the majority of tests - """ - return mocker.MagicMock(wraps=FakeHTTPPipeline()) - - -@pytest.fixture -def http_pipeline_manual_cb(mocker): - """This fixture is for use in tests where manual triggering of a - callback is required - """ - return mocker.MagicMock() - - -@pytest.fixture -def mock_mqtt_pipeline_init(mocker): - return mocker.patch("azure.iot.device.iothub.pipeline.MQTTPipeline") - - -@pytest.fixture -def mock_http_pipeline_init(mocker): - return mocker.patch("azure.iot.device.iothub.pipeline.HTTPPipeline") diff --git a/tests/unit/iothub/conftest.py b/tests/unit/iothub/conftest.py deleted file mode 100644 index c457fbdaf..000000000 --- a/tests/unit/iothub/conftest.py +++ /dev/null @@ -1,29 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -# These fixtures are shared between sync and async clients -from .client_fixtures import ( # noqa: F401 - message, - method_response, - method_request, - twin_patch_desired, - twin_patch_reported, - fake_twin, - mqtt_pipeline, - mqtt_pipeline_manual_cb, - http_pipeline, - http_pipeline_manual_cb, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - device_connection_string, - module_connection_string, - device_sas_token_string, - module_sas_token_string, - edge_container_environment, - edge_local_debug_environment, - x509, - mock_edge_hsm, -) diff --git a/tests/unit/iothub/models/test_message.py b/tests/unit/iothub/models/test_message.py deleted file mode 100644 index d444a55fd..000000000 --- a/tests/unit/iothub/models/test_message.py +++ /dev/null @@ -1,107 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import pytest -import logging -from azure.iot.device.iothub.models import Message -from azure.iot.device import constant - -logging.basicConfig(level=logging.DEBUG) - -data_str = "Some string of data" -data_int = 987 -data_obj = Message(data_str) - - -@pytest.mark.describe("Message") -class TestMessage(object): - @pytest.mark.it("Instantiates from data type") - @pytest.mark.parametrize( - "data", [data_str, data_int, data_obj], ids=["String", "Integer", "Message"] - ) - def test_instantiates_from_data(self, data): - msg = Message(data) - assert msg.data == data - - @pytest.mark.it("Instantiates with optional provided message id") - def test_instantiates_with_optional_message_id(self): - message_id = "Postage12323" - msg = Message("some message", message_id) - assert msg.message_id == message_id - - @pytest.mark.it("Instantiates with optional provided content type and content encoding") - def test_instantiates_with_optional_contenttype_encoding(self): - ctype = "application/json" - encoding = "utf-16" - msg = Message("some message", None, encoding, ctype) - assert msg.content_encoding == encoding - assert msg.content_type == ctype - - @pytest.mark.it("Instantiates with optional provided output name") - def test_instantiates_with_optional_output_name(self): - output_name = "some_output" - msg = Message("some message", output_name=output_name) - assert msg.output_name == output_name - - @pytest.mark.it("Instantiates with no custom properties set") - def test_default_custom_properties(self): - msg = Message("some message") - assert msg.custom_properties == {} - - @pytest.mark.it("Instantiates with no set expiry time") - def test_default_expiry_time(self): - msg = Message("some message") - assert msg.expiry_time_utc is None - - @pytest.mark.it("Instantiates with no set correlation id") - def test_default_corr_id(self): - msg = Message("some message") - assert msg.correlation_id is None - - 
@pytest.mark.it("Instantiates with no set user id") - def test_default_user_id(self): - msg = Message("some message") - assert msg.user_id is None - - @pytest.mark.it("Instantiates with no set input name") - def test_default_input_name(self): - msg = Message("some message") - assert msg.input_name is None - - @pytest.mark.it("Instantiates with no set ack value") - def test_default_ack(self): - msg = Message("some message") - assert msg.ack is None - - @pytest.mark.it("Instantiates with no set iothub_interface_id (i.e. not as a security message)") - def test_default_security_msg_status(self): - msg = Message("some message") - assert msg.iothub_interface_id is None - - @pytest.mark.it("Maintains iothub_interface_id (security message) as a read-only property") - def test_read_only_iothub_interface_id(self): - msg = Message("some message") - with pytest.raises(AttributeError): - msg.iothub_interface_id = "value" - - @pytest.mark.it( - "Uses string representation of data/payload attribute as string representation of Message" - ) - @pytest.mark.parametrize( - "data", [data_str, data_int, data_obj], ids=["String", "Integer", "Message"] - ) - def test_str_rep(self, data): - msg = Message(data) - assert str(msg) == str(data) - - @pytest.mark.it("Can be set as a security message via API") - def test_setting_message_as_security_message(self): - ctype = "application/json" - encoding = "utf-16" - msg = Message("some message", None, encoding, ctype) - assert msg.iothub_interface_id is None - msg.set_as_security_message() - assert msg.iothub_interface_id == constant.SECURITY_MESSAGE_INTERFACE_ID diff --git a/tests/unit/iothub/models/test_methods.py b/tests/unit/iothub/models/test_methods.py deleted file mode 100644 index 5b5415d00..000000000 --- a/tests/unit/iothub/models/test_methods.py +++ /dev/null @@ -1,117 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. 
-# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest -import logging -from azure.iot.device.iothub.models import MethodRequest, MethodResponse - -logging.basicConfig(level=logging.DEBUG) - -dummy_rid = 1 -dummy_name = "name" -dummy_payload = {"MethodPayload": "somepayload"} -dummy_status = 200 - - -@pytest.mark.describe("MethodRequest - Instantiation") -class TestMethodRequest(object): - @pytest.mark.it("Instantiates with a read-only 'request_id' attribute") - def test_request_id_property_is_read_only(self): - m_req = MethodRequest(request_id=dummy_rid, name=dummy_name, payload=dummy_payload) - new_rid = 2 - - with pytest.raises(AttributeError): - m_req.request_id = new_rid - assert m_req.request_id != new_rid - assert m_req.request_id == dummy_rid - - @pytest.mark.it("Instantiates with a read-only 'name' attribute") - def test_name_property_is_read_only(self): - m_req = MethodRequest(request_id=dummy_rid, name=dummy_name, payload=dummy_payload) - new_name = "new_name" - - with pytest.raises(AttributeError): - m_req.name = new_name - assert m_req.name != new_name - assert m_req.name == dummy_name - - @pytest.mark.it("Instantiates with a read-only 'payload' attribute") - def test_payload_property_is_read_only(self): - m_req = MethodRequest(request_id=dummy_rid, name=dummy_name, payload=dummy_payload) - new_payload = {"NewPayload": "somenewpayload"} - - with pytest.raises(AttributeError): - m_req.payload = new_payload - assert m_req.payload != new_payload - assert m_req.payload == dummy_payload - - -@pytest.mark.describe("MethodResponse - Instantiation") -class TestMethodResponseInstantiation(object): - @pytest.mark.it("Instantiates with an editable 'request_id' attribute") - def test_instantiates_with_request_id(self): - response = MethodResponse(request_id=dummy_rid, status=dummy_status, payload=dummy_payload) - assert 
response.request_id == dummy_rid - - new_rid = "2" - assert response.request_id != new_rid - response.request_id = new_rid - assert response.request_id == new_rid - - @pytest.mark.it("Instantiates with an editable 'status' attribute") - def test_instantiates_with_status(self): - response = MethodResponse(request_id=dummy_rid, status=dummy_status, payload=dummy_payload) - assert response.status == dummy_status - - new_status = 400 - assert response.status != new_status - response.status = new_status - assert response.status == new_status - - @pytest.mark.it("Instantiates with an editable 'payload' attribute") - def test_instantiates_with_payload(self): - response = MethodResponse(request_id=dummy_rid, status=dummy_status, payload=dummy_payload) - assert response.payload == dummy_payload - - new_payload = {"NewPayload": "yes_this_is_new"} - assert response.payload != new_payload - response.payload = new_payload - assert response.payload == new_payload - - @pytest.mark.it("Instantiates with a default 'payload' of 'None' if not provided") - def test_instantiates_without_payload(self): - response = MethodResponse(request_id=dummy_rid, status=dummy_status) - assert response.request_id == dummy_rid - assert response.status == dummy_status - assert response.payload is None - - -@pytest.mark.describe("MethodResponse - .create_from_method_request()") -class TestMethodResponseCreateFromMethodRequest(object): - @pytest.mark.it("Instantiates using a MethodRequest to provide the 'request_id'") - def test_instantiates_from_method_request(self): - request = MethodRequest(request_id=dummy_rid, name=dummy_name, payload=dummy_payload) - status = 200 - payload = {"ResponsePayload": "SomeResponse"} - response = MethodResponse.create_from_method_request( - method_request=request, status=status, payload=payload - ) - - assert isinstance(response, MethodResponse) - assert response.request_id == request.request_id - assert response.status == status - assert response.payload == payload - - 
@pytest.mark.it("Instantiates with a default 'payload' of 'None' if not provided") - def test_instantiates_without_payload(self): - request = MethodRequest(request_id=dummy_rid, name=dummy_name, payload=dummy_payload) - status = 200 - response = MethodResponse.create_from_method_request(request, status) - - assert isinstance(response, MethodResponse) - assert response.request_id == request.request_id - assert response.status == status - assert response.payload is None diff --git a/tests/unit/iothub/pipeline/__init__.py b/tests/unit/iothub/pipeline/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/iothub/pipeline/conftest.py b/tests/unit/iothub/pipeline/conftest.py deleted file mode 100644 index fee736bb8..000000000 --- a/tests/unit/iothub/pipeline/conftest.py +++ /dev/null @@ -1,33 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest -from tests.unit.common.pipeline.fixtures import ( # noqa: F401 - fake_pipeline_thread, - fake_non_pipeline_thread, - arbitrary_op, - arbitrary_event, - pipeline_connected_mock, - nucleus, -) - -from azure.iot.device.iothub.pipeline import constant - -# Update this list with features as they are added to the SDK -# NOTE: should this be refactored into a fixture so it doesn't have to be imported? -# Is this used anywhere that DOESN'T just turn it into a fixture? 
-all_features = [ - constant.C2D_MSG, - constant.INPUT_MSG, - constant.METHODS, - constant.TWIN, - constant.TWIN_PATCHES, -] - - -@pytest.fixture(params=all_features) -def iothub_pipeline_feature(request): - return request.param diff --git a/tests/unit/iothub/pipeline/test_config.py b/tests/unit/iothub/pipeline/test_config.py deleted file mode 100644 index 3ed31c994..000000000 --- a/tests/unit/iothub/pipeline/test_config.py +++ /dev/null @@ -1,81 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import pytest -from tests.unit.common.pipeline.config_test import PipelineConfigInstantiationTestBase -from azure.iot.device.iothub.pipeline.config import IoTHubPipelineConfig - -device_id = "my_device" -module_id = "my_module" -hostname = "hostname.some-domain.net" -product_info = "some_info" - - -@pytest.mark.describe("IoTHubPipelineConfig - Instantiation") -class TestIoTHubPipelineConfigInstantiation(PipelineConfigInstantiationTestBase): - - # This fixture is needed for tests inherited from the parent class - @pytest.fixture - def config_cls(self): - return IoTHubPipelineConfig - - # This fixture is needed for tests inherited from the parent class - @pytest.fixture - def required_kwargs(self): - return {"device_id": device_id, "hostname": hostname} - - # The parent class defines the auth mechanism fixtures (sastoken, x509). - # For the sake of ease of testing, we will assume sastoken is being used unless - # there is a strict need to do something else. - # It does not matter which is used for the purposes of these tests. 
- - @pytest.mark.it( - "Instantiates with the 'device_id' attribute set to the provided 'device_id' parameter" - ) - def test_device_id_set(self, sastoken): - config = IoTHubPipelineConfig(device_id=device_id, hostname=hostname, sastoken=sastoken) - assert config.device_id == device_id - - @pytest.mark.it( - "Instantiates with the 'module_id' attribute set to the provided 'module_id' parameter" - ) - def test_module_id_set(self, sastoken): - config = IoTHubPipelineConfig( - device_id=device_id, module_id=module_id, hostname=hostname, sastoken=sastoken - ) - assert config.module_id == module_id - - @pytest.mark.it( - "Instantiates with the 'module_id' attribute set to 'None' if no 'module_id' parameter is provided" - ) - def test_module_id_default(self, sastoken): - config = IoTHubPipelineConfig(device_id=device_id, hostname=hostname, sastoken=sastoken) - assert config.module_id is None - - @pytest.mark.it( - "Instantiates with the 'product_info' attribute set to the provided 'product_info' parameter" - ) - def test_product_info_set(self, sastoken): - config = IoTHubPipelineConfig( - device_id=device_id, hostname=hostname, product_info=product_info, sastoken=sastoken - ) - assert config.product_info == product_info - - @pytest.mark.it( - "Instantiates with the 'product_info' attribute defaulting to empty string if no 'product_info' parameter is provided" - ) - def test_product_info_default(self, sastoken): - config = IoTHubPipelineConfig(device_id=device_id, hostname=hostname, sastoken=sastoken) - assert config.product_info == "" - - @pytest.mark.it("Instantiates with the 'blob_upload' attribute set to False") - def test_blob_upload(self, sastoken): - config = IoTHubPipelineConfig(device_id=device_id, hostname=hostname, sastoken=sastoken) - assert config.blob_upload is False - - @pytest.mark.it("Instantiates with the 'method_invoke' attribute set to False") - def test_method_invoke(self, sastoken): - config = IoTHubPipelineConfig(device_id=device_id, 
hostname=hostname, sastoken=sastoken) - assert config.method_invoke is False diff --git a/tests/unit/iothub/pipeline/test_http_pipeline.py b/tests/unit/iothub/pipeline/test_http_pipeline.py deleted file mode 100644 index fe16203b5..000000000 --- a/tests/unit/iothub/pipeline/test_http_pipeline.py +++ /dev/null @@ -1,348 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest -import logging -from azure.iot.device.common.pipeline import ( - pipeline_stages_base, - pipeline_stages_http, - pipeline_ops_base, - pipeline_nucleus, -) -from azure.iot.device.iothub.pipeline import ( - pipeline_stages_iothub_http, - pipeline_ops_iothub_http, -) -from azure.iot.device.iothub.pipeline import HTTPPipeline - -logging.basicConfig(level=logging.DEBUG) -pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") - -fake_device_id = "__fake_device_id__" -fake_module_id = "__fake_module_id__" -fake_blob_name = "__fake_blob_name__" - - -@pytest.fixture -def pipeline_configuration(mocker): - mocked_configuration = mocker.MagicMock() - mocked_configuration.blob_upload = True - mocked_configuration.method_invoke = True - mocked_configuration.sastoken.ttl = 1232 # set for compat - return mocked_configuration - - -@pytest.fixture -def pipeline(mocker, pipeline_configuration): - pipeline = HTTPPipeline(pipeline_configuration) - mocker.patch.object(pipeline._pipeline, "run_op") - return pipeline - - -@pytest.fixture -def twin_patch(): - return {"key": "value"} - - -# automatically mock the transport for all tests in this file. 
-@pytest.fixture(autouse=True) -def mock_transport(mocker): - mocker.patch( - "azure.iot.device.common.pipeline.pipeline_stages_http.HTTPTransport", autospec=True - ) - - -@pytest.mark.describe("HTTPPipeline - Instantiation") -class TestHTTPPipelineInstantiation(object): - @pytest.mark.it("Configures the pipeline with a PipelineNucleus") - def test_pipeline_nucleus(self, pipeline_configuration): - pipeline = HTTPPipeline(pipeline_configuration) - - assert isinstance(pipeline._nucleus, pipeline_nucleus.PipelineNucleus) - assert pipeline._nucleus.pipeline_configuration is pipeline_configuration - - @pytest.mark.it("Configures the pipeline with a series of PipelineStages") - def test_pipeline_stages(self, pipeline_configuration): - pipeline = HTTPPipeline(pipeline_configuration) - curr_stage = pipeline._pipeline - - expected_stage_order = [ - pipeline_stages_base.PipelineRootStage, - pipeline_stages_iothub_http.IoTHubHTTPTranslationStage, - pipeline_stages_http.HTTPTransportStage, - ] - - # Assert that all PipelineStages are there, and they are in the right order - for i in range(len(expected_stage_order)): - expected_stage = expected_stage_order[i] - assert isinstance(curr_stage, expected_stage) - assert curr_stage.nucleus is pipeline._nucleus - curr_stage = curr_stage.next - - # Assert there are no more additional stages - assert curr_stage is None - - @pytest.mark.it("Runs an InitializePipelineOperation on the pipeline") - def test_sas_auth(self, mocker, pipeline_configuration): - mocker.spy(pipeline_stages_base.PipelineRootStage, "run_op") - - pipeline = HTTPPipeline(pipeline_configuration) - - op = pipeline._pipeline.run_op.call_args[0][1] - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance(op, pipeline_ops_base.InitializePipelineOperation) - - @pytest.mark.it( - "Raises exceptions that occurred in execution upon unsuccessful completion of the InitializePipelineOperation" - ) - def test_sas_auth_op_fail(self, mocker, arbitrary_exception, 
pipeline_configuration): - old_run_op = pipeline_stages_base.PipelineRootStage._run_op - - def fail_initialize(self, op): - if isinstance(op, pipeline_ops_base.InitializePipelineOperation): - op.complete(error=arbitrary_exception) - else: - old_run_op(self, op) - - mocker.patch.object( - pipeline_stages_base.PipelineRootStage, - "_run_op", - side_effect=fail_initialize, - autospec=True, - ) - - with pytest.raises(arbitrary_exception.__class__) as e_info: - HTTPPipeline(pipeline_configuration) - assert e_info.value is arbitrary_exception - - -@pytest.mark.describe("HTTPPipeline - .invoke_method()") -class TestHTTPPipelineInvokeMethod(object): - @pytest.mark.it("Runs a MethodInvokeOperation on the pipeline") - def test_runs_op(self, pipeline, mocker): - cb = mocker.MagicMock() - pipeline.invoke_method( - device_id=fake_device_id, - module_id=fake_module_id, - method_params=mocker.MagicMock(), - callback=cb, - ) - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance( - pipeline._pipeline.run_op.call_args[0][0], - pipeline_ops_iothub_http.MethodInvokeOperation, - ) - - @pytest.mark.it( - "Calls the callback with the error if the pipeline_configuration.method_invoke is not True" - ) - def test_op_configuration_fail(self, mocker, pipeline, arbitrary_exception): - pipeline._nucleus.pipeline_configuration.method_invoke = False - cb = mocker.MagicMock() - - pipeline.invoke_method( - device_id=fake_device_id, - module_id=fake_module_id, - method_params=mocker.MagicMock(), - callback=cb, - ) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=mocker.ANY) - - @pytest.mark.it("Passes the correct parameters to the MethodInvokeOperation") - def test_passes_params_to_op(self, pipeline, mocker): - cb = mocker.MagicMock() - mocked_op = mocker.patch.object(pipeline_ops_iothub_http, "MethodInvokeOperation") - fake_method_params = mocker.MagicMock() - pipeline.invoke_method( - device_id=fake_device_id, - module_id=fake_module_id, - 
method_params=fake_method_params, - callback=cb, - ) - - assert mocked_op.call_args == mocker.call( - callback=mocker.ANY, - method_params=fake_method_params, - target_device_id=fake_device_id, - target_module_id=fake_module_id, - ) - - @pytest.mark.it("Triggers the callback upon successful completion of the MethodInvokeOperation") - def test_op_success_with_callback(self, mocker, pipeline): - cb = mocker.MagicMock() - - # Begin operation - pipeline.invoke_method( - device_id=fake_device_id, - module_id=fake_module_id, - method_params=mocker.MagicMock(), - callback=cb, - ) - assert cb.call_count == 0 - - # Trigger op completion - op = pipeline._pipeline.run_op.call_args[0][0] - op.method_response = "__fake_method_response__" - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call( - error=None, invoke_method_response="__fake_method_response__" - ) - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the MethodInvokeOperation" - ) - def test_op_fail(self, mocker, pipeline, arbitrary_exception): - cb = mocker.MagicMock() - - pipeline.invoke_method( - device_id=fake_device_id, - module_id=fake_module_id, - method_params=mocker.MagicMock(), - callback=cb, - ) - op = pipeline._pipeline.run_op.call_args[0][0] - - op.complete(error=arbitrary_exception) - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=arbitrary_exception, invoke_method_response=None) - - -@pytest.mark.describe("HTTPPipeline - .get_storage_info_for_blob()") -class TestHTTPPipelineGetStorageInfo(object): - @pytest.mark.it("Runs a GetStorageInfoOperation on the pipeline") - def test_runs_op(self, pipeline, mocker): - pipeline.get_storage_info_for_blob( - blob_name="__fake_blob_name__", callback=mocker.MagicMock() - ) - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance( - pipeline._pipeline.run_op.call_args[0][0], - pipeline_ops_iothub_http.GetStorageInfoOperation, - ) - - @pytest.mark.it( - 
"Calls the callback with the error upon unsuccessful completion of the GetStorageInfoOperation" - ) - def test_op_configuration_fail(self, mocker, pipeline): - pipeline._nucleus.pipeline_configuration.blob_upload = False - cb = mocker.MagicMock() - pipeline.get_storage_info_for_blob(blob_name="__fake_blob_name__", callback=cb) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=mocker.ANY) - - @pytest.mark.it( - "Triggers the callback upon successful completion of the GetStorageInfoOperation" - ) - def test_op_success_with_callback(self, mocker, pipeline): - cb = mocker.MagicMock() - - # Begin operation - pipeline.get_storage_info_for_blob(blob_name="__fake_blob_name__", callback=cb) - assert cb.call_count == 0 - - # Trigger op completion callback - op = pipeline._pipeline.run_op.call_args[0][0] - op.storage_info = "__fake_storage_info__" - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=None, storage_info="__fake_storage_info__") - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the GetStorageInfoOperation" - ) - def test_op_fail(self, mocker, pipeline, arbitrary_exception): - cb = mocker.MagicMock() - pipeline.get_storage_info_for_blob(blob_name="__fake_blob_name__", callback=cb) - - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=arbitrary_exception) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=arbitrary_exception, storage_info=None) - - -@pytest.mark.describe("HTTPPipeline - .notify_blob_upload_status()") -class TestHTTPPipelineNotifyBlobUploadStatus(object): - @pytest.mark.it( - "Runs a NotifyBlobUploadStatusOperation with the provided parameters on the pipeline" - ) - def test_runs_op(self, pipeline, mocker): - pipeline.notify_blob_upload_status( - correlation_id="__fake_correlation_id__", - is_success="__fake_is_success__", - status_code="__fake_status_code__", - 
status_description="__fake_status_description__", - callback=mocker.MagicMock(), - ) - op = pipeline._pipeline.run_op.call_args[0][0] - - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance(op, pipeline_ops_iothub_http.NotifyBlobUploadStatusOperation) - - @pytest.mark.it( - "Calls the callback with the error if pipeline_configuration.blob_upload is not True" - ) - def test_op_configuration_fail(self, mocker, pipeline): - pipeline._nucleus.pipeline_configuration.blob_upload = False - cb = mocker.MagicMock() - pipeline.notify_blob_upload_status( - correlation_id="__fake_correlation_id__", - is_success="__fake_is_success__", - status_code="__fake_status_code__", - status_description="__fake_status_description__", - callback=cb, - ) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=mocker.ANY) - - @pytest.mark.it( - "Triggers the callback upon successful completion of the NotifyBlobUploadStatusOperation" - ) - def test_op_success_with_callback(self, mocker, pipeline): - cb = mocker.MagicMock() - - # Begin operation - pipeline.notify_blob_upload_status( - correlation_id="__fake_correlation_id__", - is_success="__fake_is_success__", - status_code="__fake_status_code__", - status_description="__fake_status_description__", - callback=cb, - ) - assert cb.call_count == 0 - - # Trigger op completion callback - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=None) - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the NotifyBlobUploadStatusOperation" - ) - def test_op_fail(self, mocker, pipeline, arbitrary_exception): - cb = mocker.MagicMock() - pipeline.notify_blob_upload_status( - correlation_id="__fake_correlation_id__", - is_success="__fake_is_success__", - status_code="__fake_status_code__", - status_description="__fake_status_description__", - callback=cb, - ) - - op = 
pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=arbitrary_exception) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=arbitrary_exception) diff --git a/tests/unit/iothub/pipeline/test_mqtt_pipeline.py b/tests/unit/iothub/pipeline/test_mqtt_pipeline.py deleted file mode 100644 index 10e1b9b1d..000000000 --- a/tests/unit/iothub/pipeline/test_mqtt_pipeline.py +++ /dev/null @@ -1,1140 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest -import logging -from azure.iot.device.common.pipeline import ( - pipeline_stages_base, - pipeline_stages_mqtt, - pipeline_ops_base, - pipeline_exceptions, - pipeline_nucleus, -) -from azure.iot.device.iothub.pipeline import ( - config, - pipeline_stages_iothub, - pipeline_stages_iothub_mqtt, - pipeline_ops_iothub, - pipeline_events_iothub, -) -from azure.iot.device.iothub.pipeline import MQTTPipeline -from .conftest import all_features - -logging.basicConfig(level=logging.DEBUG) -pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") - - -@pytest.fixture -def pipeline_configuration(mocker): - # NOTE: Consider parametrizing this to serve as both a device and module configuration. - # The reason this isn't currently done is that it's not strictly necessary, but it might be - # more correct and complete to do so. Certainly this must be done if any device/module - # specific logic is added to the code under test. 
- mock_config = config.IoTHubPipelineConfig( - device_id="my_device", hostname="my.host.name", sastoken=mocker.MagicMock() - ) - mock_config.sastoken.ttl = 1232 # set for compat - mock_config.sastoken.expiry_time = 1232131 # set for compat - return mock_config - - -@pytest.fixture -def pipeline(mocker, pipeline_configuration): - pipeline = MQTTPipeline(pipeline_configuration) - mocker.patch.object(pipeline._pipeline, "run_op") - return pipeline - - -@pytest.fixture -def twin_patch(): - return {"key": "value"} - - -# automatically mock the transport for all tests in this file. -@pytest.fixture(autouse=True) -def mock_transport(mocker): - return mocker.patch( - "azure.iot.device.common.pipeline.pipeline_stages_mqtt.MQTTTransport", autospec=True - ) - - -@pytest.mark.describe("MQTTPipeline - Instantiation") -class TestMQTTPipelineInstantiation(object): - @pytest.mark.it("Begins tracking the enabled/disabled status of features") - @pytest.mark.parametrize("feature", all_features) - def test_features(self, pipeline_configuration, feature): - pipeline = MQTTPipeline(pipeline_configuration) - pipeline.feature_enabled[feature] - # No assertion required - if this doesn't raise a KeyError, it is a success - - @pytest.mark.it("Marks all features as disabled") - def test_features_disabled(self, pipeline_configuration): - pipeline = MQTTPipeline(pipeline_configuration) - for key in pipeline.feature_enabled: - assert not pipeline.feature_enabled[key] - - @pytest.mark.it("Sets all handlers to an initial value of None") - def test_handlers_set_to_none(self, pipeline_configuration): - pipeline = MQTTPipeline(pipeline_configuration) - assert pipeline.on_connected is None - assert pipeline.on_disconnected is None - assert pipeline.on_new_sastoken_required is None - assert pipeline.on_background_exception is None - assert pipeline.on_c2d_message_received is None - assert pipeline.on_input_message_received is None - assert pipeline.on_method_request_received is None - assert 
pipeline.on_twin_patch_received is None - - @pytest.mark.it("Configures the pipeline to trigger handlers in response to external events") - def test_handlers_configured(self, pipeline_configuration): - pipeline = MQTTPipeline(pipeline_configuration) - assert pipeline._pipeline.on_pipeline_event_handler is not None - assert pipeline._pipeline.on_connected_handler is not None - assert pipeline._pipeline.on_disconnected_handler is not None - - @pytest.mark.it("Configures the pipeline with a PipelineNucleus") - def test_pipeline_nucleus(self, pipeline_configuration): - pipeline = MQTTPipeline(pipeline_configuration) - - assert isinstance(pipeline._nucleus, pipeline_nucleus.PipelineNucleus) - assert pipeline._nucleus.pipeline_configuration is pipeline_configuration - - @pytest.mark.it("Configures the pipeline with a series of PipelineStages") - def test_pipeline_stages(self, pipeline_configuration): - pipeline = MQTTPipeline(pipeline_configuration) - curr_stage = pipeline._pipeline - - expected_stage_order = [ - pipeline_stages_base.PipelineRootStage, - pipeline_stages_base.SasTokenStage, - pipeline_stages_iothub.EnsureDesiredPropertiesStage, - pipeline_stages_iothub.TwinRequestResponseStage, - pipeline_stages_base.CoordinateRequestAndResponseStage, - pipeline_stages_iothub_mqtt.IoTHubMQTTTranslationStage, - pipeline_stages_base.AutoConnectStage, - pipeline_stages_base.ConnectionStateStage, - pipeline_stages_base.RetryStage, - pipeline_stages_base.OpTimeoutStage, - pipeline_stages_mqtt.MQTTTransportStage, - ] - - # Assert that all PipelineStages are there, and they are in the right order - for i in range(len(expected_stage_order)): - expected_stage = expected_stage_order[i] - assert isinstance(curr_stage, expected_stage) - assert curr_stage.nucleus is pipeline._nucleus - curr_stage = curr_stage.next - - # Assert there are no more additional stages - assert curr_stage is None - - @pytest.mark.it("Runs an InitializePipelineOperation on the pipeline") - def 
test_init_pipeline(self, mocker, pipeline_configuration): - mocker.spy(pipeline_stages_base.PipelineRootStage, "run_op") - - pipeline = MQTTPipeline(pipeline_configuration) - - op = pipeline._pipeline.run_op.call_args[0][1] - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance(op, pipeline_ops_base.InitializePipelineOperation) - - @pytest.mark.it( - "Sets a flag to indicate the pipeline is 'running' upon successful completion of the InitializePipelineOperation" - ) - def test_running(self, mocker, pipeline_configuration): - # Because this is an init test, there isn't really a way to check that it only occurs after - # the op. The reason is because this is the object's init, the object doesn't actually - # exist until the entire method has completed, so there's no reference you can check prior - # to method completion. - pipeline = MQTTPipeline(pipeline_configuration) - assert pipeline._running - - @pytest.mark.it( - "Raises exceptions that occurred in execution upon unsuccessful completion of the InitializePipelineOperation" - ) - def test_init_pipeline_fail(self, mocker, arbitrary_exception, pipeline_configuration): - old_run_op = pipeline_stages_base.PipelineRootStage._run_op - - def fail_initialize(self, op): - if isinstance(op, pipeline_ops_base.InitializePipelineOperation): - op.complete(error=arbitrary_exception) - else: - old_run_op(self, op) - - mocker.patch.object( - pipeline_stages_base.PipelineRootStage, - "_run_op", - side_effect=fail_initialize, - autospec=True, - ) - - with pytest.raises(arbitrary_exception.__class__) as e_info: - MQTTPipeline(pipeline_configuration) - assert e_info.value is arbitrary_exception - - -@pytest.mark.describe("MQTTPipeline - .shutdown()") -class TestMQTTPipelineShutdown(object): - @pytest.mark.it( - "Raises a PipelineNotRunning exception if the pipeline is not running (i.e. 
already shut down)" - ) - def test_not_running(self, mocker, pipeline): - pipeline._running = False - - with pytest.raises(pipeline_exceptions.PipelineNotRunning): - pipeline.shutdown(callback=mocker.MagicMock()) - - @pytest.mark.it("Runs a ShutdownPipelineOperation on the pipeline") - def test_runs_op(self, pipeline, mocker): - cb = mocker.MagicMock() - pipeline.shutdown(callback=cb) - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance( - pipeline._pipeline.run_op.call_args[0][0], pipeline_ops_base.ShutdownPipelineOperation - ) - - @pytest.mark.it( - "Triggers the callback upon successful completion of the ShutdownPipelineOperation" - ) - def test_op_success_with_callback(self, mocker, pipeline): - cb = mocker.MagicMock() - - # Begin operation - pipeline.shutdown(callback=cb) - assert cb.call_count == 0 - - # Trigger op completion - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=None) - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the ShutdownPipelineOperation" - ) - def test_op_fail(self, mocker, pipeline, arbitrary_exception): - cb = mocker.MagicMock() - - pipeline.shutdown(callback=cb) - op = pipeline._pipeline.run_op.call_args[0][0] - - op.complete(error=arbitrary_exception) - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=arbitrary_exception) - - @pytest.mark.it( - "Sets a flag to indicate the pipeline is no longer running only upon successful completion of the ShutdownPipelineOperation" - ) - def test_set_not_running(self, mocker, pipeline, arbitrary_exception): - # Pipeline is running - assert pipeline._running - - # Begin operation (we will fail this one) - cb = mocker.MagicMock() - pipeline.shutdown(callback=cb) - assert cb.call_count == 0 - - # Pipeline is still running - assert pipeline._running - - # Trigger op completion (failure) - op = 
pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=arbitrary_exception) - - # Pipeline is still running - assert pipeline._running - - # Try operation again (we will make this one succeed) - cb.reset_mock() - pipeline.shutdown(callback=cb) - assert cb.call_count == 0 - - # Trigger op completion (successful) - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=None) - - # Pipeline is no longer running - assert not pipeline._running - - -@pytest.mark.describe("MQTTPipeline - .connect()") -class TestMQTTPipelineConnect(object): - @pytest.mark.it( - "Raises a PipelineNotRunning exception if the pipeline is not running (i.e. already shut down)" - ) - def test_not_running(self, mocker, pipeline): - pipeline._running = False - - with pytest.raises(pipeline_exceptions.PipelineNotRunning): - pipeline.connect(callback=mocker.MagicMock()) - - @pytest.mark.it("Runs a ConnectOperation on the pipeline") - def test_runs_op(self, pipeline, mocker): - cb = mocker.MagicMock() - pipeline.connect(callback=cb) - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance( - pipeline._pipeline.run_op.call_args[0][0], pipeline_ops_base.ConnectOperation - ) - - @pytest.mark.it("Triggers the callback upon successful completion of the ConnectOperation") - def test_op_success_with_callback(self, mocker, pipeline): - cb = mocker.MagicMock() - - # Begin operation - pipeline.connect(callback=cb) - assert cb.call_count == 0 - - # Trigger op completion - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=None) - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the ConnectOperation" - ) - def test_op_fail(self, mocker, pipeline, arbitrary_exception): - cb = mocker.MagicMock() - - pipeline.connect(callback=cb) - op = pipeline._pipeline.run_op.call_args[0][0] - - op.complete(error=arbitrary_exception) - assert cb.call_count == 1 - 
assert cb.call_args == mocker.call(error=arbitrary_exception) - - -@pytest.mark.describe("MQTTPipeline - .disconnect()") -class TestMQTTPipelineDisconnect(object): - @pytest.mark.it( - "Raises a PipelineNotRunning exception if the pipeline is not running (i.e. already shut down)" - ) - def test_not_running(self, mocker, pipeline): - pipeline._running = False - - with pytest.raises(pipeline_exceptions.PipelineNotRunning): - pipeline.disconnect(callback=mocker.MagicMock()) - - @pytest.mark.it("Runs a DisconnectOperation on the pipeline") - def test_runs_op(self, pipeline, mocker): - pipeline.disconnect(callback=mocker.MagicMock()) - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance( - pipeline._pipeline.run_op.call_args[0][0], pipeline_ops_base.DisconnectOperation - ) - - @pytest.mark.it("Triggers the callback upon successful completion of the DisconnectOperation") - def test_op_success_with_callback(self, mocker, pipeline): - cb = mocker.MagicMock() - - # Begin operation - pipeline.disconnect(callback=cb) - assert cb.call_count == 0 - - # Trigger op completion callback - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=None) - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the DisconnectOperation" - ) - def test_op_fail(self, mocker, pipeline, arbitrary_exception): - cb = mocker.MagicMock() - pipeline.disconnect(callback=cb) - - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=arbitrary_exception) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=arbitrary_exception) - - -@pytest.mark.describe("MQTTPipeline - .reauthorize_connection()") -class TestMQTTPipelineReauthorizeConnection(object): - @pytest.mark.it( - "Raises a PipelineNotRunning exception if the pipeline is not running (i.e. 
already shut down)" - ) - def test_not_running(self, mocker, pipeline): - pipeline._running = False - - with pytest.raises(pipeline_exceptions.PipelineNotRunning): - pipeline.reauthorize_connection(callback=mocker.MagicMock()) - - @pytest.mark.it("Runs a ReauthorizeConnectionOperation on the pipeline") - def test_runs_op(self, pipeline, mocker): - pipeline.reauthorize_connection(callback=mocker.MagicMock()) - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance( - pipeline._pipeline.run_op.call_args[0][0], - pipeline_ops_base.ReauthorizeConnectionOperation, - ) - - @pytest.mark.it( - "Triggers the callback upon successful completion of the ReauthorizeConnectionOperation" - ) - def test_op_success_with_callback(self, mocker, pipeline): - cb = mocker.MagicMock() - - # Begin operation - pipeline.reauthorize_connection(callback=cb) - assert cb.call_count == 0 - - # Trigger oop completion callback - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=None) - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the ReauthorizeConnectionOperation" - ) - def test_op_fail(self, mocker, pipeline, arbitrary_exception): - cb = mocker.MagicMock() - pipeline.reauthorize_connection(callback=cb) - - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=arbitrary_exception) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=arbitrary_exception) - - -@pytest.mark.describe("MQTTPipeline - .send_message()") -class TestMQTTPipelineSendD2CMessage(object): - @pytest.mark.it( - "Raises a PipelineNotRunning exception if the pipeline is not running (i.e. 
already shut down)" - ) - def test_not_running(self, mocker, message, pipeline): - pipeline._running = False - - with pytest.raises(pipeline_exceptions.PipelineNotRunning): - pipeline.send_message(message, callback=mocker.MagicMock()) - - @pytest.mark.it("Runs a SendD2CMessageOperation with the provided message on the pipeline") - def test_runs_op(self, pipeline, message, mocker): - pipeline.send_message(message, callback=mocker.MagicMock()) - op = pipeline._pipeline.run_op.call_args[0][0] - - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance(op, pipeline_ops_iothub.SendD2CMessageOperation) - assert op.message == message - - @pytest.mark.it( - "Triggers the callback upon successful completion of the SendD2CMessageOperation" - ) - def test_op_success_with_callback(self, mocker, pipeline, message): - cb = mocker.MagicMock() - - # Begin operation - pipeline.send_message(message, callback=cb) - assert cb.call_count == 0 - - # Trigger op completion callback - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=None) - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the SendD2CMessageOperation" - ) - def test_op_fail(self, mocker, pipeline, message, arbitrary_exception): - cb = mocker.MagicMock() - pipeline.send_message(message, callback=cb) - - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=arbitrary_exception) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=arbitrary_exception) - - -@pytest.mark.describe("MQTTPipeline - .send_output_message()") -class TestMQTTPipelineSendOutputMessage(object): - @pytest.mark.it( - "Raises a PipelineNotRunning exception if the pipeline is not running (i.e. 
already shut down)" - ) - def test_not_running(self, mocker, message, pipeline): - pipeline._running = False - - with pytest.raises(pipeline_exceptions.PipelineNotRunning): - pipeline.send_output_message(message, callback=mocker.MagicMock()) - - @pytest.fixture - def message(self, message): - """Modify message fixture to have an output""" - message.output_name = "some output" - return message - - @pytest.mark.it("Runs a SendOutputMessageOperation with the provided Message on the pipeline") - def test_runs_op(self, pipeline, message, mocker): - pipeline.send_output_message(message, callback=mocker.MagicMock()) - op = pipeline._pipeline.run_op.call_args[0][0] - - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance(op, pipeline_ops_iothub.SendOutputMessageOperation) - assert op.message == message - - @pytest.mark.it( - "Triggers the callback upon successful completion of the SendOutputMessageOperation" - ) - def test_op_success_with_callback(self, mocker, pipeline, message): - cb = mocker.MagicMock() - - # Begin operation - pipeline.send_output_message(message, callback=cb) - assert cb.call_count == 0 - - # Trigger op completion callback - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=None) - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the SendOutputMessageOperation" - ) - def test_op_fail(self, mocker, pipeline, message, arbitrary_exception): - cb = mocker.MagicMock() - pipeline.send_output_message(message, callback=cb) - - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=arbitrary_exception) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=arbitrary_exception) - - -@pytest.mark.describe("MQTTPipeline - .send_method_response()") -class TestMQTTPipelineSendMethodResponse(object): - @pytest.mark.it( - "Raises a PipelineNotRunning exception if the pipeline is not running 
(i.e. already shut down)" - ) - def test_not_running(self, mocker, method_response, pipeline): - pipeline._running = False - - with pytest.raises(pipeline_exceptions.PipelineNotRunning): - pipeline.send_method_response(method_response, callback=mocker.MagicMock()) - - @pytest.mark.it( - "Runs a SendMethodResponseOperation with the provided MethodResponse on the pipeline" - ) - def test_runs_op(self, pipeline, method_response, mocker): - pipeline.send_method_response(method_response, callback=mocker.MagicMock()) - op = pipeline._pipeline.run_op.call_args[0][0] - - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance(op, pipeline_ops_iothub.SendMethodResponseOperation) - assert op.method_response == method_response - - @pytest.mark.it( - "Triggers the callback upon successful completion of the SendMethodResponseOperation" - ) - def test_op_success_with_callback(self, mocker, pipeline, method_response): - cb = mocker.MagicMock() - - # Begin operation - pipeline.send_method_response(method_response, callback=cb) - assert cb.call_count == 0 - - # Trigger op completion callback - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=None) - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the SendMethodResponseOperation" - ) - def test_op_fail(self, mocker, pipeline, method_response, arbitrary_exception): - cb = mocker.MagicMock() - pipeline.send_method_response(method_response, callback=cb) - - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=arbitrary_exception) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=arbitrary_exception) - - -@pytest.mark.describe("MQTTPipeline - .get_twin()") -class TestMQTTPipelineGetTwin(object): - @pytest.mark.it( - "Raises a PipelineNotRunning exception if the pipeline is not running (i.e. 
already shut down)" - ) - def test_not_running(self, mocker, pipeline): - pipeline._running = False - - with pytest.raises(pipeline_exceptions.PipelineNotRunning): - pipeline.get_twin(callback=mocker.MagicMock()) - - @pytest.mark.it("Runs a GetTwinOperation on the pipeline") - def test_runs_op(self, mocker, pipeline): - cb = mocker.MagicMock() - pipeline.get_twin(callback=cb) - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance( - pipeline._pipeline.run_op.call_args[0][0], pipeline_ops_iothub.GetTwinOperation - ) - - @pytest.mark.it( - "Triggers the provided callback upon successful completion of the GetTwinOperation" - ) - def test_op_success_with_callback(self, mocker, pipeline): - cb = mocker.MagicMock() - - # Begin operation - pipeline.get_twin(callback=cb) - assert cb.call_count == 0 - - # Trigger op completion callback - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(twin=None) - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the GetTwinOperation" - ) - def test_op_fail(self, mocker, pipeline, arbitrary_exception): - cb = mocker.MagicMock() - pipeline.get_twin(callback=cb) - - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=arbitrary_exception) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(twin=None, error=arbitrary_exception) - - -@pytest.mark.describe("MQTTPipeline - .patch_twin_reported_properties()") -class TestMQTTPipelinePatchTwinReportedProperties(object): - @pytest.mark.it( - "Raises a PipelineNotRunning exception if the pipeline is not running (i.e. 
already shut down)" - ) - def test_not_running(self, mocker, twin_patch, pipeline): - pipeline._running = False - - with pytest.raises(pipeline_exceptions.PipelineNotRunning): - pipeline.patch_twin_reported_properties(twin_patch, callback=mocker.MagicMock()) - - @pytest.mark.it( - "Runs a PatchTwinReportedPropertiesOperation with the provided twin patch on the pipeline" - ) - def test_runs_op(self, pipeline, twin_patch, mocker): - pipeline.patch_twin_reported_properties(twin_patch, callback=mocker.MagicMock()) - op = pipeline._pipeline.run_op.call_args[0][0] - - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance(op, pipeline_ops_iothub.PatchTwinReportedPropertiesOperation) - assert op.patch == twin_patch - - @pytest.mark.it( - "Triggers the callback upon successful completion of the PatchTwinReportedPropertiesOperation" - ) - def test_op_success_with_callback(self, mocker, pipeline, twin_patch): - cb = mocker.MagicMock() - - # Begin operation - pipeline.patch_twin_reported_properties(twin_patch, callback=cb) - assert cb.call_count == 0 - - # Trigger op completion callback - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=None) - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the PatchTwinReportedPropertiesOperation" - ) - def test_op_fail(self, mocker, pipeline, twin_patch, arbitrary_exception): - cb = mocker.MagicMock() - pipeline.patch_twin_reported_properties(twin_patch, callback=cb) - - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=arbitrary_exception) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=arbitrary_exception) - - -@pytest.mark.describe("MQTTPipeline - .enable_feature()") -class TestMQTTPipelineEnableFeature(object): - @pytest.mark.it( - "Raises a PipelineNotRunning exception if the pipeline is not running (i.e. 
already shut down)" - ) - @pytest.mark.parametrize("feature", all_features) - def test_not_running(self, mocker, feature, pipeline): - pipeline._running = False - - with pytest.raises(pipeline_exceptions.PipelineNotRunning): - pipeline.enable_feature(feature, callback=mocker.MagicMock()) - - @pytest.mark.it("Raises ValueError if the feature_name is invalid") - def test_invalid_feature_name(self, pipeline, mocker): - bad_feature = "not-a-feature-name" - assert bad_feature not in pipeline.feature_enabled - with pytest.raises(ValueError): - pipeline.enable_feature(bad_feature, callback=mocker.MagicMock()) - assert bad_feature not in pipeline.feature_enabled - - # TODO: what about features that are already disabled? - - @pytest.mark.it("Runs a EnableFeatureOperation with the provided feature_name on the pipeline") - @pytest.mark.parametrize("feature", all_features) - def test_runs_op(self, pipeline, feature, mocker): - pipeline.enable_feature(feature, callback=mocker.MagicMock()) - op = pipeline._pipeline.run_op.call_args[0][0] - - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance(op, pipeline_ops_base.EnableFeatureOperation) - assert op.feature_name == feature - - @pytest.mark.it("Does not mark the feature as enabled before the callback is complete") - @pytest.mark.parametrize("feature", all_features) - def test_mark_feature_not_enabled(self, pipeline, feature, mocker): - assert not pipeline.feature_enabled[feature] - callback = mocker.MagicMock() - pipeline.enable_feature(feature, callback=callback) - - assert callback.call_count == 0 - assert not pipeline.feature_enabled[feature] - - @pytest.mark.it("Does not mark the feature as enabled if the EnableFeatureOperation fails") - @pytest.mark.parametrize("feature", all_features) - def test_mark_feature_not_enabled_on_failure( - self, pipeline, feature, mocker, arbitrary_exception - ): - assert not pipeline.feature_enabled[feature] - callback = mocker.MagicMock() - pipeline.enable_feature(feature, 
callback=callback) - - op = pipeline._pipeline.run_op.call_args[0][0] - assert isinstance(op, pipeline_ops_base.EnableFeatureOperation) - op.complete(arbitrary_exception) - - assert callback.call_count == 1 - assert not pipeline.feature_enabled[feature] - - @pytest.mark.it("Marks the feature as enabled if the EnableFeatureOperation succeeds") - @pytest.mark.parametrize("feature", all_features) - def test_mark_feature_enabled_on_success(self, pipeline, feature, mocker): - assert not pipeline.feature_enabled[feature] - callback = mocker.MagicMock() - pipeline.enable_feature(feature, callback=callback) - - op = pipeline._pipeline.run_op.call_args[0][0] - assert isinstance(op, pipeline_ops_base.EnableFeatureOperation) - op.complete() - - assert callback.call_count == 1 - assert pipeline.feature_enabled[feature] - - @pytest.mark.it( - "Triggers the callback upon successful completion of the EnableFeatureOperation" - ) - @pytest.mark.parametrize("feature", all_features) - def test_op_success_with_callback(self, mocker, pipeline, feature): - cb = mocker.MagicMock() - - # Begin operation - pipeline.enable_feature(feature, callback=cb) - assert cb.call_count == 0 - - # Trigger op completion callback - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=None) - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the EnableFeatureOperation" - ) - @pytest.mark.parametrize("feature", all_features) - def test_op_fail(self, mocker, pipeline, feature, arbitrary_exception): - cb = mocker.MagicMock() - pipeline.enable_feature(feature, callback=cb) - - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=arbitrary_exception) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=arbitrary_exception) - - -@pytest.mark.describe("MQTTPipeline - .disable_feature()") -class TestMQTTPipelineDisableFeature(object): - 
@pytest.mark.it( - "Raises a PipelineNotRunning exception if the pipeline is not running (i.e. already shut down)" - ) - @pytest.mark.parametrize("feature", all_features) - def test_not_running(self, mocker, feature, pipeline): - pipeline._running = False - - with pytest.raises(pipeline_exceptions.PipelineNotRunning): - pipeline.disable_feature(feature, callback=mocker.MagicMock()) - - @pytest.mark.it("Raises ValueError if the feature_name is invalid") - def test_invalid_feature_name(self, pipeline, mocker): - bad_feature = "not-a-feature-name" - assert bad_feature not in pipeline.feature_enabled - with pytest.raises(ValueError): - pipeline.disable_feature(bad_feature, callback=mocker.MagicMock()) - assert bad_feature not in pipeline.feature_enabled - - # TODO: what about features that are already disabled? - - @pytest.mark.it("Runs a DisableFeatureOperation with the provided feature_name on the pipeline") - @pytest.mark.parametrize("feature", all_features) - def test_runs_op(self, pipeline, feature, mocker): - pipeline.disable_feature(feature, callback=mocker.MagicMock()) - op = pipeline._pipeline.run_op.call_args[0][0] - - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance(op, pipeline_ops_base.DisableFeatureOperation) - assert op.feature_name == feature - - @pytest.mark.it("Does not mark the feature as disabled before the callback is complete") - @pytest.mark.parametrize("feature", all_features) - def test_mark_feature_not_enabled(self, pipeline, feature, mocker): - # feature is already enabled - pipeline.feature_enabled[feature] = True - assert pipeline.feature_enabled[feature] - - # start call to disable feature - callback = mocker.MagicMock() - pipeline.disable_feature(feature, callback=callback) - - # feature is still enabled (because callback has not been completed yet) - assert callback.call_count == 0 - assert pipeline.feature_enabled[feature] - - @pytest.mark.it("Marks the feature as disabled if the DisableFeatureOperation succeeds") - 
@pytest.mark.parametrize("feature", all_features) - def test_mark_feature_enabled_on_success(self, pipeline, feature, mocker): - # feature is already enabled - pipeline.feature_enabled[feature] = True - assert pipeline.feature_enabled[feature] - - # try to disable the feature (and succeed) - callback = mocker.MagicMock() - pipeline.disable_feature(feature, callback=callback) - op = pipeline._pipeline.run_op.call_args[0][0] - assert isinstance(op, pipeline_ops_base.DisableFeatureOperation) - op.complete() - - assert callback.call_count == 1 - assert not pipeline.feature_enabled[feature] - - @pytest.mark.it("Marks the feature as disabled even if the DisableFeatureOperation fails") - @pytest.mark.parametrize("feature", all_features) - def test_mark_feature_not_enabled_on_failure( - self, pipeline, feature, mocker, arbitrary_exception - ): - # feature is already enabled - pipeline.feature_enabled[feature] = True - assert pipeline.feature_enabled[feature] - - # tyr to disable the feature (but fail) - callback = mocker.MagicMock() - pipeline.disable_feature(feature, callback=callback) - op = pipeline._pipeline.run_op.call_args[0][0] - assert isinstance(op, pipeline_ops_base.DisableFeatureOperation) - op.complete(arbitrary_exception) - - # Feature was STILL disabled - assert callback.call_count == 1 - assert not pipeline.feature_enabled[feature] - - @pytest.mark.it( - "Triggers the callback upon successful completion of the DisableFeatureOperation" - ) - @pytest.mark.parametrize("feature", all_features) - def test_op_success_with_callback(self, mocker, pipeline, feature): - cb = mocker.MagicMock() - - # Begin operation - pipeline.disable_feature(feature, callback=cb) - assert cb.call_count == 0 - - # Trigger op completion callback - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=None) - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of 
the DisableFeatureOperation" - ) - @pytest.mark.parametrize("feature", all_features) - def _est_op_fail(self, mocker, pipeline, feature, arbitrary_exception): - cb = mocker.MagicMock() - pipeline.disable_feature(feature, callback=cb) - - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=arbitrary_exception) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=arbitrary_exception) - - -@pytest.mark.describe("MQTTPipeline - OCCURRENCE: Connected") -class TestMQTTPipelineOCCURRENCEConnect(object): - @pytest.mark.it("Triggers the 'on_connected' handler") - def test_with_handler(self, mocker, pipeline): - # Set the handler - mock_handler = mocker.MagicMock() - pipeline.on_connected = mock_handler - assert mock_handler.call_count == 0 - - # Trigger the connect - pipeline._pipeline.on_connected_handler() - - assert mock_handler.call_count == 1 - assert mock_handler.call_args == mocker.call() - - @pytest.mark.it("Does nothing if the 'on_connected' handler is not set") - def test_without_handler(self, pipeline): - pipeline._pipeline.on_connected_handler() - - # No assertions required - not throwing an exception means the test passed - - -@pytest.mark.describe("MQTTPipeline - OCCURRENCE: Disconnected") -class TestMQTTPipelineOCCURRENCEDisconnect(object): - @pytest.mark.it("Triggers the 'on_disconnected' handler") - def test_with_handler(self, mocker, pipeline): - # Set the handler - mock_handler = mocker.MagicMock() - pipeline.on_disconnected = mock_handler - assert mock_handler.call_count == 0 - - # Trigger the disconnect - pipeline._pipeline.on_disconnected_handler() - - assert mock_handler.call_count == 1 - assert mock_handler.call_args == mocker.call() - - @pytest.mark.it("Does nothing if the 'on_disconnected' handler is not set") - def test_without_handler(self, pipeline): - pipeline._pipeline.on_disconnected_handler() - - # No assertions required - not throwing an exception means the test passed - - 
-@pytest.mark.describe("MQTTPipeline - OCCURRENCE: New Sastoken Required") -class TestMQTTPipelineOCCURRENCENewSastokenRequired(object): - @pytest.mark.it("Triggers the 'on_new_sastoken_required' handler") - def test_with_handler(self, mocker, pipeline): - # Set the handler - mock_handler = mocker.MagicMock() - pipeline.on_new_sastoken_required = mock_handler - assert mock_handler.call_count == 0 - - # Trigger the event - pipeline._pipeline.on_new_sastoken_required_handler() - - assert mock_handler.call_count == 1 - assert mock_handler.call_args == mocker.call() - - @pytest.mark.it("Does nothing if the 'on_new_sastoken_required' handler is not set") - def test_without_handler(self, pipeline): - pipeline._pipeline.on_new_sastoken_required_handler() - - # No assertions required - not throwing an exception means the test passed - - -@pytest.mark.describe("MQTTPipeline - OCCURRENCE: Background Exception") -class TestMQTTPipelineOCCURRENCEBackgroundException(object): - @pytest.mark.it("Triggers the 'on_background_exception' handler") - def test_with_handler(self, mocker, pipeline, arbitrary_exception): - # Set the handler - mock_handler = mocker.MagicMock() - pipeline.on_background_exception = mock_handler - assert mock_handler.call_count == 0 - - # Trigger the background exception - pipeline._pipeline.on_background_exception_handler(arbitrary_exception) - - assert mock_handler.call_count == 1 - assert mock_handler.call_args == mocker.call(arbitrary_exception) - - @pytest.mark.it("Does nothing if the 'on_background_exception' handler is not set") - def test_without_handler(self, pipeline, arbitrary_exception): - pipeline._pipeline.on_background_exception_handler(arbitrary_exception) - - # No assertions required - not throwing an exception means the test passed - - -@pytest.mark.describe("MQTTPipeline - OCCURRENCE: C2D Message Received") -class TestMQTTPipelineOCCURRENCEReceiveC2DMessage(object): - @pytest.mark.it( - "Triggers the 'on_c2d_message_received' handler, 
passing the received message as an argument" - ) - def test_with_handler(self, mocker, pipeline, message): - # Set the handler - mock_handler = mocker.MagicMock() - pipeline.on_c2d_message_received = mock_handler - assert mock_handler.call_count == 0 - - # Create the event - c2d_event = pipeline_events_iothub.C2DMessageEvent(message) - - # Trigger the event - pipeline._pipeline.on_pipeline_event_handler(c2d_event) - - assert mock_handler.call_count == 1 - assert mock_handler.call_args == mocker.call(message) - - @pytest.mark.it("Drops the message if the 'on_c2d_message_received' handler is not set") - def test_no_handler(self, pipeline, message): - c2d_event = pipeline_events_iothub.C2DMessageEvent(message) - pipeline._pipeline.on_pipeline_event_handler(c2d_event) - - # No assertions required - not throwing an exception means the test passed - - -@pytest.mark.describe("MQTTPipeline - OCCURRENCE: Input Message Received") -class TestMQTTPipelineOCCURRENCEReceiveInputMessage(object): - @pytest.mark.it( - "Triggers the 'on_input_message_received' handler, passing the received message as an argument" - ) - def test_with_handler(self, mocker, pipeline, message): - # Set the handler - mock_handler = mocker.MagicMock() - pipeline.on_input_message_received = mock_handler - assert mock_handler.call_count == 0 - - # Create the event - input_name = "some_input" - message.input_name = input_name - input_message_event = pipeline_events_iothub.InputMessageEvent(message) - - # Trigger the event - pipeline._pipeline.on_pipeline_event_handler(input_message_event) - - assert mock_handler.call_count == 1 - assert mock_handler.call_args == mocker.call(message) - - @pytest.mark.it("Drops the message if the 'on_input_message_received' handler is not set") - def test_no_handler(self, pipeline, message): - input_name = "some_input" - message.input_name = input_name - input_message_event = pipeline_events_iothub.InputMessageEvent(message) - 
pipeline._pipeline.on_pipeline_event_handler(input_message_event) - - # No assertions required - not throwing an exception means the test passed - - -@pytest.mark.describe("MQTTPipeline - OCCURRENCE: Method Request Received") -class TestMQTTPipelineOCCURRENCEReceiveMethodRequest(object): - @pytest.mark.it( - "Triggers the 'on_method_request_received' handler, passing the received method request as an argument" - ) - def test_with_handler(self, mocker, pipeline, method_request): - # Set the handler - mock_handler = mocker.MagicMock() - pipeline.on_method_request_received = mock_handler - assert mock_handler.call_count == 0 - - # Create the event - method_request_event = pipeline_events_iothub.MethodRequestEvent(method_request) - - # Trigger the event - pipeline._pipeline.on_pipeline_event_handler(method_request_event) - - assert mock_handler.call_count == 1 - assert mock_handler.call_args == mocker.call(method_request) - - @pytest.mark.it( - "Drops the method request if the 'on_method_request_received' handler is not set" - ) - def test_no_handler(self, pipeline, method_request): - method_request_event = pipeline_events_iothub.MethodRequestEvent(method_request) - pipeline._pipeline.on_pipeline_event_handler(method_request_event) - - # No assertions required - not throwing an exception means the test passed - - -@pytest.mark.describe("MQTTPipeline - OCCURRENCE: Twin Desired Properties Patch Received") -class TestMQTTPipelineOCCURRENCEReceiveDesiredPropertiesPatch(object): - @pytest.mark.it( - "Triggers the 'on_twin_patch_received' handler, passing the received twin patch as an argument" - ) - def test_with_handler(self, mocker, pipeline, twin_patch): - # Set the handler - mock_handler = mocker.MagicMock() - pipeline.on_twin_patch_received = mock_handler - assert mock_handler.call_count == 0 - - # Create the event - twin_patch_event = pipeline_events_iothub.TwinDesiredPropertiesPatchEvent(twin_patch) - - # Trigger the event - 
pipeline._pipeline.on_pipeline_event_handler(twin_patch_event) - - assert mock_handler.call_count == 1 - assert mock_handler.call_args == mocker.call(twin_patch) - - @pytest.mark.it("Drops the twin patch if the 'on_twin_patch_received' handler is not set") - def test_no_handler(self, pipeline, twin_patch): - twin_patch_event = pipeline_events_iothub.TwinDesiredPropertiesPatchEvent(twin_patch) - pipeline._pipeline.on_pipeline_event_handler(twin_patch_event) - - # No assertions required - not throwing an exception means the test passed - - -@pytest.mark.describe("MQTTPipeline - PROPERTY .pipeline_configuration") -class TestMQTTPipelinePROPERTYPipelineConfiguration(object): - @pytest.mark.it("Value of the object cannot be changed") - def test_read_only(self, pipeline): - with pytest.raises(AttributeError): - pipeline.pipeline_configuration = 12 - - @pytest.mark.it("Values ON the object CAN be changed") - def test_update_values_on_read_only_object(self, pipeline): - assert pipeline.pipeline_configuration.sastoken is not None - pipeline.pipeline_configuration.sastoken = None - assert pipeline.pipeline_configuration.sastoken is None - - @pytest.mark.it("Reflects the value of the PipelineNucleus attribute of the same name") - def test_reflects_pipeline_attribute(self, pipeline): - assert pipeline.pipeline_configuration is pipeline._nucleus.pipeline_configuration - - -@pytest.mark.describe("MQTTPipeline - PROPERTY .connected") -class TestMQTTPipelinePROPERTYConnected(object): - @pytest.mark.it("Cannot be changed") - def test_read_only(self, pipeline): - with pytest.raises(AttributeError): - pipeline.connected = not pipeline.connected - - @pytest.mark.it("Reflects the value of the PipelineNucleus attribute of the same name") - def test_reflects_pipeline_attribute(self, pipeline, pipeline_connected_mock): - # Need to set indirectly via mock due to nucleus attribute being read-only - type(pipeline._nucleus).connected = pipeline_connected_mock - 
pipeline_connected_mock.return_value = True - assert pipeline._nucleus.connected - assert pipeline.connected - # Again, must be set indirectly - pipeline_connected_mock.return_value = False - assert not pipeline._nucleus.connected - assert not pipeline.connected diff --git a/tests/unit/iothub/pipeline/test_mqtt_topic_iothub.py b/tests/unit/iothub/pipeline/test_mqtt_topic_iothub.py deleted file mode 100644 index 8e633be51..000000000 --- a/tests/unit/iothub/pipeline/test_mqtt_topic_iothub.py +++ /dev/null @@ -1,1433 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest -import logging -import datetime -from azure.iot.device.iothub.pipeline import mqtt_topic_iothub -from azure.iot.device import Message - -logging.basicConfig(level=logging.DEBUG) - -# NOTE: All tests (that require it) are parametrized with multiple values for URL encoding. -# This is to show that the URL encoding is done correctly - not all URL encoding encodes -# the same way. -# -# For URL encoding, we must always test the ' ' and '/' characters specifically, in addition -# to a generic URL encoding value (e.g. $, #, etc.) -# -# For URL decoding, we must always test the '+' character specifically, in addition to -# a generic URL encoded value (e.g. %24, %23, etc.) -# -# Please also always test that provided values are converted to strings in order to ensure -# that they can be URL encoded without error. -# -# PLEASE DO THESE TESTS FOR EVEN CASES WHERE THOSE CHARACTERS SHOULD NOT OCCUR, FOR SAFETY. 
- - -@pytest.mark.describe(".get_c2d_topic_for_subscribe()") -class TestGetC2DTopicForSubscribe(object): - @pytest.mark.it("Returns the topic for subscribing to C2D messages from IoTHub") - def test_returns_topic(self): - device_id = "my_device" - expected_topic = "devices/my_device/messages/devicebound/#" - topic = mqtt_topic_iothub.get_c2d_topic_for_subscribe(device_id) - assert topic == expected_topic - - # NOTE: It SHOULD do URL encoding, but Hub doesn't currently support URL decoding, so we have - # to follow that and not do URL encoding for safety. As a result, some of the values used in - # this test would actually be invalid in production due to character restrictions on the Hub - # that exist to prevent Hub from breaking due to a lack of URL decoding. - # If Hub does begin to support robust URL encoding for safety, this test can easily be switched - # to show that URL encoding DOES work. - @pytest.mark.it("Does NOT URL encode the device_id when generating the topic") - @pytest.mark.parametrize( - "device_id, expected_topic", - [ - pytest.param( - "my$device", "devices/my$device/messages/devicebound/#", id="id contains '$'" - ), - pytest.param( - "my device", "devices/my device/messages/devicebound/#", id="id contains ' '" - ), - pytest.param( - "my/device", "devices/my/device/messages/devicebound/#", id="id contains '/'" - ), - ], - ) - def test_url_encoding(self, device_id, expected_topic): - topic = mqtt_topic_iothub.get_c2d_topic_for_subscribe(device_id) - assert topic == expected_topic - - @pytest.mark.it("Converts the device_id to string when generating the topic") - def test_str_conversion(self): - device_id = 2000 - expected_topic = "devices/2000/messages/devicebound/#" - topic = mqtt_topic_iothub.get_c2d_topic_for_subscribe(device_id) - assert topic == expected_topic - - -@pytest.mark.describe(".get_input_topic_for_subscribe()") -class TestGetInputTopicForSubscribe(object): - @pytest.mark.it("Returns the topic for subscribing to Input messages from 
IoTHub") - def test_returns_topic(self): - device_id = "my_device" - module_id = "my_module" - expected_topic = "devices/my_device/modules/my_module/inputs/#" - topic = mqtt_topic_iothub.get_input_topic_for_subscribe(device_id, module_id) - assert topic == expected_topic - - # NOTE: It SHOULD do URL encoding, but Hub doesn't currently support URL decoding, so we have - # to follow that and not do URL encoding for safety. As a result, some of the values used in - # this test would actually be invalid in production due to character restrictions on the Hub - # that exist to prevent Hub from breaking due to a lack of URL decoding. - # If Hub does begin to support robust URL encoding for safety, this test can easily be switched - # to show that URL encoding DOES work. - @pytest.mark.it("URL encodes the device_id and module_id when generating the topic") - @pytest.mark.parametrize( - "device_id, module_id, expected_topic", - [ - pytest.param( - "my$device", - "my$module", - "devices/my$device/modules/my$module/inputs/#", - id="ids contain '$'", - ), - pytest.param( - "my device", - "my module", - "devices/my device/modules/my module/inputs/#", - id="ids contain ' '", - ), - pytest.param( - "my/device", - "my/module", - "devices/my/device/modules/my/module/inputs/#", - id="ids contain '/'", - ), - ], - ) - def test_url_encoding(self, device_id, module_id, expected_topic): - topic = mqtt_topic_iothub.get_input_topic_for_subscribe(device_id, module_id) - assert topic == expected_topic - - @pytest.mark.it("Converts the device_id and module_id to string when generating the topic") - def test_str_conversion(self): - device_id = 2000 - module_id = 4000 - expected_topic = "devices/2000/modules/4000/inputs/#" - topic = mqtt_topic_iothub.get_input_topic_for_subscribe(device_id, module_id) - assert topic == expected_topic - - -@pytest.mark.describe(".get_method_topic_for_subscribe()") -class TestGetMethodTopicForSubscribe(object): - @pytest.mark.it("Returns the topic for 
subscribing to methods from IoTHub") - def test_returns_topic(self): - topic = mqtt_topic_iothub.get_method_topic_for_subscribe() - assert topic == "$iothub/methods/POST/#" - - -@pytest.mark.describe("get_twin_response_topic_for_subscribe()") -class TestGetTwinResponseTopicForSubscribe(object): - @pytest.mark.it("Returns the topic for subscribing to twin response from IoTHub") - def test_returns_topic(self): - topic = mqtt_topic_iothub.get_twin_response_topic_for_subscribe() - assert topic == "$iothub/twin/res/#" - - -@pytest.mark.describe("get_twin_patch_topic_for_subscribe()") -class TestGetTwinPatchTopicForSubscribe(object): - @pytest.mark.it("Returns the topic for subscribing to twin patches from IoTHub") - def test_returns_topic(self): - topic = mqtt_topic_iothub.get_twin_patch_topic_for_subscribe() - assert topic == "$iothub/twin/PATCH/properties/desired/#" - - -@pytest.mark.describe(".get_telemetry_topic_for_publish()") -class TestGetTelemetryTopicForPublish(object): - @pytest.mark.it("Returns the topic for sending telemetry to IoTHub") - @pytest.mark.parametrize( - "device_id, module_id, expected_topic", - [ - pytest.param("my_device", None, "devices/my_device/messages/events/", id="Device"), - pytest.param( - "my_device", - "my_module", - "devices/my_device/modules/my_module/messages/events/", - id="Module", - ), - ], - ) - def test_returns_topic(self, device_id, module_id, expected_topic): - topic = mqtt_topic_iothub.get_telemetry_topic_for_publish(device_id, module_id) - assert topic == expected_topic - - # NOTE: It SHOULD do URL encoding, but Hub doesn't currently support URL decoding, so we have - # to follow that and not do URL encoding for safety. As a result, some of the values used in - # this test would actually be invalid in production due to character restrictions on the Hub - # that exist to prevent Hub from breaking due to a lack of URL decoding. 
- # If Hub does begin to support robust URL encoding for safety, this test can easily be switched - # to show that URL encoding DOES work. - @pytest.mark.it("URL encodes the device_id and module_id when generating the topic") - @pytest.mark.parametrize( - "device_id, module_id, expected_topic", - [ - pytest.param( - "my$device", - None, - "devices/my$device/messages/events/", - id="Device, id contains '$'", - ), - pytest.param( - "my device", - None, - "devices/my device/messages/events/", - id="Device, id contains ' '", - ), - pytest.param( - "my/device", - None, - "devices/my/device/messages/events/", - id="Device, id contains '/'", - ), - pytest.param( - "my$device", - "my$module", - "devices/my$device/modules/my$module/messages/events/", - id="Module, ids contain '$'", - ), - pytest.param( - "my device", - "my module", - "devices/my device/modules/my module/messages/events/", - id="Module, ids contain ' '", - ), - pytest.param( - "my/device", - "my/module", - "devices/my/device/modules/my/module/messages/events/", - id="Module, ids contain '/'", - ), - ], - ) - def test_url_encoding(self, device_id, module_id, expected_topic): - topic = mqtt_topic_iothub.get_telemetry_topic_for_publish(device_id, module_id) - assert topic == expected_topic - - @pytest.mark.it("Converts the device_id and module_id to string when generating the topic") - @pytest.mark.parametrize( - "device_id, module_id, expected_topic", - [ - pytest.param(2000, None, "devices/2000/messages/events/", id="Device"), - pytest.param(2000, 4000, "devices/2000/modules/4000/messages/events/", id="Module"), - ], - ) - def test_str_conversion(self, device_id, module_id, expected_topic): - topic = mqtt_topic_iothub.get_telemetry_topic_for_publish(device_id, module_id) - assert topic == expected_topic - - -@pytest.mark.describe(".get_method_topic_for_publish()") -class TestGetMethodTopicForPublish(object): - @pytest.mark.it("Returns the topic for sending a method response to IoTHub") - 
@pytest.mark.parametrize( - "request_id, status, expected_topic", - [ - pytest.param("1", "200", "$iothub/methods/res/200/?$rid=1", id="Successful result"), - pytest.param( - "475764", "500", "$iothub/methods/res/500/?$rid=475764", id="Failure result" - ), - ], - ) - def test_returns_topic(self, request_id, status, expected_topic): - topic = mqtt_topic_iothub.get_method_topic_for_publish(request_id, status) - assert topic == expected_topic - - @pytest.mark.it("URL encodes provided values when generating the topic") - @pytest.mark.parametrize( - "request_id, status, expected_topic", - [ - pytest.param( - "invalid#request?id", - "invalid$status", - "$iothub/methods/res/invalid%24status/?$rid=invalid%23request%3Fid", - id="Standard URL Encoding", - ), - pytest.param( - "invalid request id", - "invalid status", - "$iothub/methods/res/invalid%20status/?$rid=invalid%20request%20id", - id="URL Encoding of ' ' character", - ), - pytest.param( - "invalid/request/id", - "invalid/status", - "$iothub/methods/res/invalid%2Fstatus/?$rid=invalid%2Frequest%2Fid", - id="URL Encoding of '/' character", - ), - ], - ) - def test_url_encoding(self, request_id, status, expected_topic): - topic = mqtt_topic_iothub.get_method_topic_for_publish(request_id, status) - assert topic == expected_topic - - @pytest.mark.it("Converts the provided values to strings when generating the topic") - def test_str_conversion(self): - request_id = 1 - status = 200 - expected_topic = "$iothub/methods/res/200/?$rid=1" - topic = mqtt_topic_iothub.get_method_topic_for_publish(request_id, status) - assert topic == expected_topic - - -@pytest.mark.describe(".get_twin_topic_for_publish()") -class TestGetTwinTopicForPublish(object): - @pytest.mark.it("Returns topic for sending a twin request to IoTHub") - @pytest.mark.parametrize( - "method, resource_location, request_id, expected_topic", - [ - # Get Twin - pytest.param( - "GET", - "/", - "3226c2f7-3d30-425c-b83b-0c34335f8220", - 
"$iothub/twin/GET/?$rid=3226c2f7-3d30-425c-b83b-0c34335f8220", - id="Get Twin", - ), - # Patch Twin - pytest.param( - "POST", - "/properties/reported/", - "5002b415-af16-47e9-b89c-8680e01b502f", - "$iothub/twin/POST/properties/reported/?$rid=5002b415-af16-47e9-b89c-8680e01b502f", - id="Patch Twin", - ), - ], - ) - def test_returns_topic(self, method, resource_location, request_id, expected_topic): - topic = mqtt_topic_iothub.get_twin_topic_for_publish(method, resource_location, request_id) - assert topic == expected_topic - - @pytest.mark.it("URL encodes 'request_id' parameter") - @pytest.mark.parametrize( - "method, resource_location, request_id, expected_topic", - [ - pytest.param( - "GET", - "/", - "invalid$request?id", - "$iothub/twin/GET/?$rid=invalid%24request%3Fid", - id="Get Twin, Standard URL Encoding", - ), - pytest.param( - "GET", - "/", - "invalid request id", - "$iothub/twin/GET/?$rid=invalid%20request%20id", - id="Get Twin, URL Encoding of ' ' character", - ), - pytest.param( - "GET", - "/", - "invalid/request/id", - "$iothub/twin/GET/?$rid=invalid%2Frequest%2Fid", - id="Get Twin, URL Encoding of '/' character", - ), - pytest.param( - "POST", - "/properties/reported/", - "invalid$request?id", - "$iothub/twin/POST/properties/reported/?$rid=invalid%24request%3Fid", - id="Patch Twin, Standard URL Encoding", - ), - pytest.param( - "POST", - "/properties/reported/", - "invalid request id", - "$iothub/twin/POST/properties/reported/?$rid=invalid%20request%20id", - id="Patch Twin, URL Encoding of ' ' character", - ), - pytest.param( - "POST", - "/properties/reported/", - "invalid/request/id", - "$iothub/twin/POST/properties/reported/?$rid=invalid%2Frequest%2Fid", - id="Patch Twin, URL Encoding of '/' character", - ), - ], - ) - def test_url_encoding(self, method, resource_location, request_id, expected_topic): - topic = mqtt_topic_iothub.get_twin_topic_for_publish(method, resource_location, request_id) - assert topic == expected_topic - - 
@pytest.mark.it("Converts 'request_id' parameter to string when generating the topic") - @pytest.mark.parametrize( - "method, resource_location, request_id, expected_topic", - [ - # Get Twin - pytest.param("GET", "/", 4000, "$iothub/twin/GET/?$rid=4000", id="Get Twin"), - # Patch Twin - pytest.param( - "POST", - "/properties/reported/", - 2000, - "$iothub/twin/POST/properties/reported/?$rid=2000", - id="Patch Twin", - ), - ], - ) - def test_str_conversion(self, method, resource_location, request_id, expected_topic): - topic = mqtt_topic_iothub.get_twin_topic_for_publish(method, resource_location, request_id) - assert topic == expected_topic - - -@pytest.mark.describe(".is_c2d_topic()") -class TestIsC2DTopic(object): - @pytest.mark.it( - "Returns True if the provided topic is a C2D topic and matches the provided device id" - ) - def test_is_c2d_topic(self): - topic = "devices/fake_device/messages/devicebound/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmessages%2Fdevicebound" - device_id = "fake_device" - assert mqtt_topic_iothub.is_c2d_topic(topic, device_id) - - # NOTE: It SHOULD do URL encoding, but Hub doesn't currently support URL decoding, so we have - # to follow that and not do URL encoding for safety. As a result, some of the values used in - # this test would actually be invalid in production due to character restrictions on the Hub - # that exist to prevent Hub from breaking due to a lack of URL decoding. - # If Hub does begin to support robust URL encoding for safety, this test can easily be switched - # to show that URL encoding DOES work. 
- @pytest.mark.it("Does NOT URL encode the device id when matching to the topic") - @pytest.mark.parametrize( - "topic, device_id", - [ - pytest.param( - "devices/fake?device/messages/devicebound/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake?device%2Fmessages%2Fdevicebound", - "fake?device", - id="Standard URL encoding required for device_id", - ), - pytest.param( - "devices/fake device/messages/devicebound/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake device%2Fmessages%2Fdevicebound", - "fake device", - id="URL encoding of ' ' character required for device_id", - ), - # Note that this topic string is completely broken, even beyond the fact that device id's can't have a '/' in them. - # A device id with a '/' would not be possible to decode correctly, because the '/' in the device name encoded in the - # system properties would cause the system properties to not be able to be decoded correctly. But, like many tests - # this is just for completeness, safety, and consistency. 
- pytest.param( - "devices/fake/device/messages/devicebound/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake/device%2Fmessages%2Fdevicebound", - "fake/device", - id="URL encoding of '/' character required for device_id", - ), - ], - ) - def test_url_encodes(self, topic, device_id): - assert mqtt_topic_iothub.is_c2d_topic(topic, device_id) - - @pytest.mark.it("Converts the device id to string when matching to the topic") - def test_str_conversion(self): - topic = "devices/2000/messages/devicebound/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2F2000%2Fmessages%2Fdevicebound" - device_id = 2000 - assert mqtt_topic_iothub.is_c2d_topic(topic, device_id) - - @pytest.mark.it("Returns False if the provided topic is not a C2D topic") - @pytest.mark.parametrize( - "topic, device_id", - [ - pytest.param("not a topic", "fake_device", id="Not a topic"), - pytest.param( - "devices/fake_device/modules/fake_module/inputs/fake_input/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmessages%2Fdevicebound", - "fake_device", - id="Topic of wrong type", - ), - pytest.param( - "devices/fake_device/msgs/devicebound/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmessages%2Fdevicebound", - "fake_device", - id="Malformed topic", - ), - ], - ) - def test_is_not_c2d_topic(self, topic, device_id): - assert not mqtt_topic_iothub.is_c2d_topic(topic, device_id) - - @pytest.mark.it( - "Returns False if the provided topic is a C2D topic, but does not match the provided device id" - ) - def test_is_c2d_topic_but_wrong_device_id(self): - topic = "devices/fake_device/messages/devicebound/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmessages%2Fdevicebound" - device_id = "VERY_fake_device" - assert not mqtt_topic_iothub.is_c2d_topic(topic, device_id) - - -@pytest.mark.describe(".is_input_topic()") -class TestIsInputTopic(object): - @pytest.mark.it( - "Returns True if the 
provided topic is an input topic and matches the provided device id and module id" - ) - def test_is_input_topic(self): - topic = "devices/fake_device/modules/fake_module/inputs/fake_input/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmodules%2Ffake_module%2Finputs%2Ffake_input" - device_id = "fake_device" - module_id = "fake_module" - assert mqtt_topic_iothub.is_input_topic(topic, device_id, module_id) - - # NOTE: It SHOULD do URL encoding, but Hub doesn't currently support URL decoding, so we have - # to follow that and not do URL encoding for safety. As a result, some of the values used in - # this test would actually be invalid in production due to character restrictions on the Hub - # that exist to prevent Hub from breaking due to a lack of URL decoding. - # If Hub does begin to support robust URL encoding for safety, this test can easily be switched - # to show that URL encoding DOES work. - @pytest.mark.it("Does NOT URL encode the device id and module_id when matching to the topic") - @pytest.mark.parametrize( - "topic, device_id, module_id", - [ - pytest.param( - "devices/fake?device/modules/fake$module/inputs/fake%23input/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake%3Fdevice%2Fmodules%2Ffake%24module%2Finputs%2Ffake%23input", - "fake?device", - "fake$module", - id="Standard URL encoding required for ids", - ), - pytest.param( - "devices/fake device/modules/fake module/inputs/fake%20input/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake%20device%2Fmodules%2Ffake%20module%2Finputs%2Ffake%20input", - "fake device", - "fake module", - id="URL encoding for ' ' character required for ids", - ), - pytest.param( - "devices/fake/device/modules/fake/module/inputs/fake%20input/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake%2Fdevice%2Fmodules%2Ffake%2Fmodule%2Finputs%2Ffake%2Finput", - "fake/device", - "fake/module", - id="URL encoding for '/' character required for 
ids", - ), - ], - ) - def test_url_encodes(self, topic, device_id, module_id): - assert mqtt_topic_iothub.is_input_topic(topic, device_id, module_id) - - @pytest.mark.it("Converts the device_id and module_id to string when matching to the topic") - def test_str_conversion(self): - topic = "devices/2000/modules/4000/inputs/fake_input/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2F2000%2Fmodules%2F4000%2Finputs%2Ffake_input" - device_id = 2000 - module_id = 4000 - assert mqtt_topic_iothub.is_input_topic(topic, device_id, module_id) - - @pytest.mark.it("Returns False if the provided topic is not an input topic") - @pytest.mark.parametrize( - "topic, device_id, module_id", - [ - pytest.param("not a topic", "fake_device", "fake_module", id="Not a topic"), - pytest.param( - "devices/fake_device/messages/devicebound/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmessages%2Fdevicebound", - "fake_device", - "fake_module", - id="Topic of wrong type", - ), - pytest.param( - "deivces/fake_device/modules/fake_module/inputs/fake_input/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmodules%2Ffake_module%2Finputs%2Ffake_input", - "fake_device", - "fake_module", - id="Malformed topic", - ), - ], - ) - def test_is_not_input_topic(self, topic, device_id, module_id): - assert not mqtt_topic_iothub.is_input_topic(topic, device_id, module_id) - - @pytest.mark.it( - "Returns False if the provided topic is an input topic, but does match the provided device id and/or module_id" - ) - @pytest.mark.parametrize( - "device_id, module_id", - [ - pytest.param("VERY_fake_device", "fake_module", id="Non-matching device_id"), - pytest.param("fake_device", "VERY_fake_module", id="Non-matching module_id"), - pytest.param( - "VERY_fake_device", "VERY_fake_module", id="Non-matching device_id AND module_id" - ), - pytest.param(None, "fake_module", id="No device_id"), - pytest.param("fake_device", None, id="No module_id"), 
- ], - ) - def test_is_input_topic_but_wrong_id(self, device_id, module_id): - topic = "devices/fake_device/modules/fake_module/inputs/fake_input/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmodules%2Ffake_module%2Finputs%2Ffake_input" - assert not mqtt_topic_iothub.is_input_topic(topic, device_id, module_id) - - -@pytest.mark.describe(".is_method_topic()") -class TestIsMethodTopic(object): - @pytest.mark.it("Returns True if the provided topic is a method topic") - def test_is_method_topic(self): - topic = "$iothub/methods/POST/fake_method/?$rid=1" - assert mqtt_topic_iothub.is_method_topic(topic) - - @pytest.mark.it("Returns False if the provided topic is not a method topic") - @pytest.mark.parametrize( - "topic", - [ - pytest.param("not a topic", id="Not a topic"), - pytest.param( - "devices/fake_device/messages/devicebound/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmessages%2Fdevicebound", - id="Topic of wrong type", - ), - pytest.param("$iothub/mthds/POST/fake_method/?$rid=1", id="Malformed topic"), - ], - ) - def test_is_not_method_topic(self, topic): - assert not mqtt_topic_iothub.is_method_topic(topic) - - -@pytest.mark.describe(".is_twin_response_topic()") -class TestIsTwinResponseTopic(object): - @pytest.mark.it("Returns True if the provided topic is a twin response topic") - def test_is_twin_response_topic(self): - topic = "$iothub/twin/res/200/?$rid=d9d7ce4d-3be9-498b-abde-913b81b880e5" - assert mqtt_topic_iothub.is_twin_response_topic(topic) - - @pytest.mark.it("Returns False if the provided topic is not a twin response topic") - @pytest.mark.parametrize( - "topic", - [ - pytest.param("not a topic", id="Not a topic"), - pytest.param("$iothub/methods/POST/fake_method/?$rid=1", id="Topic of wrong type"), - pytest.param( - "$iothub/twin/rs/200/?$rid=d9d7ce4d-3be9-498b-abde-913b81b880e5", - id="Malformed topic", - ), - ], - ) - def test_is_not_twin_response_topic(self, topic): - assert 
not mqtt_topic_iothub.is_twin_response_topic(topic) - - -@pytest.mark.describe(".is_twin_desired_property_patch_topic()") -class TestIsTwinDesiredPropertyPatchTopic(object): - @pytest.mark.it("Returns True if the provided topic is a desired property patch topic") - def test_is_desired_property_patch_topic(self): - topic = "$iothub/twin/PATCH/properties/desired/?$version=1" - assert mqtt_topic_iothub.is_twin_desired_property_patch_topic(topic) - - @pytest.mark.it("Returns False if the provided topic is not a desired property patch topic") - @pytest.mark.parametrize( - "topic", - [ - pytest.param("not a topic", id="Not a topic"), - pytest.param("$iothub/methods/POST/fake_method/?$rid=1", id="Topic of wrong type"), - pytest.param("$iothub/twin/PATCH/properties/dsiered/?$version=1", id="Malformed topic"), - ], - ) - def test_is_not_desired_property_patch_topic(self, topic): - assert not mqtt_topic_iothub.is_twin_desired_property_patch_topic(topic) - - -@pytest.mark.describe(".get_input_name_from_topic()") -class TestGetInputNameFromTopic(object): - @pytest.mark.it("Returns the input name from an input topic") - def test_valid_input_topic(self): - topic = "devices/fake_device/modules/fake_module/inputs/fake_input/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmodules%2Ffake_module%2Finputs%2Ffake_input" - expected_input_name = "fake_input" - - assert mqtt_topic_iothub.get_input_name_from_topic(topic) == expected_input_name - - @pytest.mark.it("URL decodes the returned input name") - @pytest.mark.parametrize( - "topic, expected_input_name", - [ - pytest.param( - "devices/fake_device/modules/fake_module/inputs/fake%24input/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmodules%2Ffake_module%2Finputs%2Ffake%24input", - "fake$input", - id="Standard URL Decoding", - ), - pytest.param( - 
"devices/fake_device/modules/fake_module/inputs/fake+input/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmodules%2Ffake_module%2Finputs%2Ffake+input", - "fake+input", - id="Does NOT decode '+' character", - ), - ], - ) - def test_url_decodes_value(self, topic, expected_input_name): - assert mqtt_topic_iothub.get_input_name_from_topic(topic) == expected_input_name - - @pytest.mark.it("Raises a ValueError if the provided topic is not an input name topic") - @pytest.mark.parametrize( - "topic", - [ - pytest.param("not a topic", id="Not a topic"), - pytest.param("$iothub/methods/POST/fake_method/?$rid=1", id="Topic of wrong type"), - pytest.param( - "devices/fake_device/inputs/fake_input/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmodules%2Ffake_module%2Finputs%2Ffake_input", - id="Malformed topic", - ), - ], - ) - def test_invalid_input_topic(self, topic): - with pytest.raises(ValueError): - mqtt_topic_iothub.get_input_name_from_topic(topic) - - -@pytest.mark.describe(".get_method_name_from_topic()") -class TestGetMethodNameFromTopic(object): - @pytest.mark.it("Returns the method name from a method topic") - def test_valid_method_topic(self): - topic = "$iothub/methods/POST/fake_method/?$rid=1" - expected_method_name = "fake_method" - - assert mqtt_topic_iothub.get_method_name_from_topic(topic) == expected_method_name - - @pytest.mark.it("URL decodes the returned method name") - @pytest.mark.parametrize( - "topic, expected_method_name", - [ - pytest.param( - "$iothub/methods/POST/fake%24method/?$rid=1", - "fake$method", - id="Standard URL Decoding", - ), - pytest.param( - "$iothub/methods/POST/fake+method/?$rid=1", - "fake+method", - id="Does NOT decode '+' character", - ), - ], - ) - def test_url_decodes_value(self, topic, expected_method_name): - assert mqtt_topic_iothub.get_method_name_from_topic(topic) == expected_method_name - - @pytest.mark.it("Raises a ValueError if the provided topic is not a 
method topic") - @pytest.mark.parametrize( - "topic", - [ - pytest.param("not a topic", id="Not a topic"), - pytest.param( - "devices/fake_device/modules/fake_module/inputs/fake_input", - id="Topic of wrong type", - ), - pytest.param("$iothub/methdos/POST/fake_method/?$rid=1", id="Malformed topic"), - ], - ) - def test_invalid_method_topic(self, topic): - with pytest.raises(ValueError): - mqtt_topic_iothub.get_method_name_from_topic(topic) - - -@pytest.mark.describe(".get_method_request_id_from_topic()") -class TestGetMethodRequestIdFromTopic(object): - @pytest.mark.it("Returns the request id from a method topic") - def test_valid_method_topic(self): - topic = "$iothub/methods/POST/fake_method/?$rid=1" - expected_request_id = "1" - - assert mqtt_topic_iothub.get_method_request_id_from_topic(topic) == expected_request_id - - @pytest.mark.it("URL decodes the returned value") - @pytest.mark.parametrize( - "topic, expected_request_id", - [ - pytest.param( - "$iothub/methods/POST/fake_method/?$rid=fake%24request%2Fid", - "fake$request/id", - id="Standard URL Decoding", - ), - pytest.param( - "$iothub/methods/POST/fake_method/?$rid=fake+request+id", - "fake+request+id", - id="Does NOT decode '+' character", - ), - ], - ) - def test_url_decodes_value(self, topic, expected_request_id): - assert mqtt_topic_iothub.get_method_request_id_from_topic(topic) == expected_request_id - - @pytest.mark.it("Raises a ValueError if the provided topic is not a method topic") - @pytest.mark.parametrize( - "topic", - [ - pytest.param("not a topic", id="Not a topic"), - pytest.param( - "devices/fake_device/modules/fake_module/inputs/fake_input", - id="Topic of wrong type", - ), - pytest.param("$iothub/methdos/POST/fake_method/?$rid=1", id="Malformed topic"), - ], - ) - def test_invalid_method_topic(self, topic): - with pytest.raises(ValueError): - mqtt_topic_iothub.get_method_request_id_from_topic(topic) - - -@pytest.mark.describe(".get_twin_request_id_from_topic()") -class 
TestGetTwinRequestIdFromTopic(object): - @pytest.mark.it("Returns the request id from a twin response topic") - def test_valid_twin_response_topic(self): - topic = "$iothub/twin/res/200/?rid=1" - expected_request_id = "1" - - assert mqtt_topic_iothub.get_twin_request_id_from_topic(topic) == expected_request_id - - @pytest.mark.it("URL decodes the returned value") - @pytest.mark.parametrize( - "topic, expected_request_id", - [ - pytest.param( - "$iothub/twin/res/200/?rid=fake%24request%2Fid", - "fake$request/id", - id="Standard URL Decoding", - ), - pytest.param( - "$iothub/twin/res/200/?rid=fake+request+id", - "fake+request+id", - id="Does NOT decode '+' character", - ), - ], - ) - def test_url_decodes_value(self, topic, expected_request_id): - assert mqtt_topic_iothub.get_twin_request_id_from_topic(topic) == expected_request_id - - @pytest.mark.it("Raises a ValueError if the provided topic is not a twin response topic") - @pytest.mark.parametrize( - "topic", - [ - pytest.param("not a topic", id="Not a topic"), - pytest.param( - "devices/fake_device/modules/fake_module/inputs/fake_input", - id="Topic of wrong type", - ), - pytest.param("$iothub/twn/res/200?rid=1", id="Malformed topic"), - ], - ) - def test_invalid_twin_response_topic(self, topic): - with pytest.raises(ValueError): - mqtt_topic_iothub.get_twin_request_id_from_topic(topic) - - -@pytest.mark.describe(".get_twin_status_code_from_topic()") -class TestGetTwinStatusCodeFromTopic(object): - @pytest.mark.it("Returns the status from a twin response topic") - def test_valid_twin_response_topic(self): - topic = "$iothub/twin/res/200/?rid=1" - expected_status = "200" - - assert mqtt_topic_iothub.get_twin_status_code_from_topic(topic) == expected_status - - @pytest.mark.it("URL decodes the returned value") - @pytest.mark.parametrize( - "topic, expected_status", - [ - pytest.param("$iothub/twin/res/%24%24%24/?rid=1", "$$$", id="Standard URL decoding"), - pytest.param( - "$iothub/twin/res/invalid+status/?rid=1", - 
"invalid+status", - id="Does NOT decode '+' character", - ), - ], - ) - def test_url_decode(self, topic, expected_status): - assert mqtt_topic_iothub.get_twin_status_code_from_topic(topic) == expected_status - - @pytest.mark.it("Raises a ValueError if the provided topic is not a twin response topic") - @pytest.mark.parametrize( - "topic", - [ - pytest.param("not a topic", id="Not a topic"), - pytest.param( - "devices/fake_device/modules/fake_module/inputs/fake_input", - id="Topic of wrong type", - ), - pytest.param("$iothub/twn/res/200?rid=1", id="Malformed topic"), - ], - ) - def test_invalid_twin_response_topic(self, topic): - with pytest.raises(ValueError): - mqtt_topic_iothub.get_twin_request_id_from_topic(topic) - - -@pytest.mark.describe(".extract_message_properties_from_topic()") -class TestExtractMessagePropertiesFromTopic(object): - @pytest.mark.it("Adds properties from topic to Message object") - @pytest.mark.parametrize( - "topic, expected_system_properties, expected_custom_properties", - [ - pytest.param( - "devices/fake_device/messages/devicebound/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmessages%2Fdevicebound", - {"mid": "6b822696-f75a-46f5-8b02-0680db65abf5"}, - {}, - id="C2D message topic, Mandatory system properties", - ), - pytest.param( - "devices/fake_device/messages/devicebound/%24.exp=3237-07-19T23%3A06%3A40.0000000Z&%24.cid=fake_corid&%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmessages%2Fdevicebound&%24.ct=fake_content_type&%24.ce=utf-8&iothub-ack=positive", - { - "mid": "6b822696-f75a-46f5-8b02-0680db65abf5", - "exp": "3237-07-19T23:06:40.0000000Z", - "cid": "fake_corid", - "ct": "fake_content_type", - "ce": "utf-8", - "iothub-ack": "positive", - }, - {}, - id="C2D message topic, All system properties", - ), - pytest.param( - 
"devices/fake_device/messages/devicebound/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmessages%2Fdevicebound&custom1=value1&custom2=value2&custom3=value3", - {"mid": "6b822696-f75a-46f5-8b02-0680db65abf5"}, - {"custom1": "value1", "custom2": "value2", "custom3": "value3"}, - id="C2D message topic, Custom properties", - ), - pytest.param( - "devices/fake_device/modules/fake_module/inputs/fake_input/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmodules%2Ffake_module%2Finputs%2Ffake_input", - {"mid": "6b822696-f75a-46f5-8b02-0680db65abf5"}, - {}, - id="Input message topic, Mandatory system properties", - ), - pytest.param( - "devices/fake_device/modules/fake_module/inputs/fake_input/%24.exp=3237-07-19T23%3A06%3A40.0000000Z&%24.cid=fake_corid&%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmodules%2Ffake_module%2Finputs%2Ffake_input&%24.ct=fake_content_type&%24.ce=utf-8&iothub-ack=positive", - { - "mid": "6b822696-f75a-46f5-8b02-0680db65abf5", - "exp": "3237-07-19T23:06:40.0000000Z", - "cid": "fake_corid", - "ct": "fake_content_type", - "ce": "utf-8", - "iothub-ack": "positive", - }, - {}, - id="Input message topic, All system properties", - ), - pytest.param( - "devices/fake_device/modules/fake_module/inputs/fake_input/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmodules%2Ffake_module%2Finputs%2Ffake_input&custom1=value1&custom2=value2&custom3=value3", - {"mid": "6b822696-f75a-46f5-8b02-0680db65abf5"}, - {"custom1": "value1", "custom2": "value2", "custom3": "value3"}, - id="Input message topic, Custom properties", - ), - ], - ) - def test_extracts_properties( - self, topic, expected_system_properties, expected_custom_properties - ): - msg = Message("fake message") - mqtt_topic_iothub.extract_message_properties_from_topic(topic, msg) - - # Validate MANDATORY system properties - assert msg.message_id == expected_system_properties["mid"] - 
- # Validate OPTIONAL system properties - assert msg.correlation_id == expected_system_properties.get("cid", None) - assert msg.user_id == expected_system_properties.get("uid", None) - assert msg.content_type == expected_system_properties.get("ct", None) - assert msg.content_encoding == expected_system_properties.get("ce", None) - assert msg.expiry_time_utc == expected_system_properties.get("exp", None) - assert msg.ack == expected_system_properties.get( - "iothub-ack", - ) - - # Validate custom properties - assert msg.custom_properties == expected_custom_properties - - @pytest.mark.it("URL decodes properties from the topic when extracting") - @pytest.mark.parametrize( - "topic, expected_system_properties, expected_custom_properties", - [ - pytest.param( - "devices/fake%24device/messages/devicebound/%24.exp=3237-07-19T23%3A06%3A40.0000000Z&%24.cid=fake%23corid&%24.mid=message%24id&%24.to=%2Fdevices%2Ffake%24device%2Fmessages%2Fdevicebound&%24.ct=fake%23content%24type&%24.ce=utf-%24&iothub-ack=po%24itive&custom%2A=value%23&custom%26=value%24&custom%25=value%40", - { - "mid": "message$id", - "exp": "3237-07-19T23:06:40.0000000Z", - "cid": "fake#corid", - "ct": "fake#content$type", - "ce": "utf-$", - "iothub-ack": "po$itive", - }, - {"custom*": "value#", "custom&": "value$", "custom%": "value@"}, - id="C2D message topic, Standard URL decoding", - ), - pytest.param( - "devices/fake+device/messages/devicebound/%24.exp=3237-07-19T23%3A06%3A40.0000000Z&%24.cid=fake+corid&%24.mid=message+id&%24.to=%2Fdevices%2Ffake+device%2Fmessages%2Fdevicebound&%24.ct=fake+content+type&%24.ce=utf-+&iothub-ack=posi+ive&custom+1=value+1&custom+2=value+2&custom+3=value+3", - { - "mid": "message+id", - "exp": "3237-07-19T23:06:40.0000000Z", - "cid": "fake+corid", - "ct": "fake+content+type", - "ce": "utf-+", - "iothub-ack": "posi+ive", - }, - {"custom+1": "value+1", "custom+2": "value+2", "custom+3": "value+3"}, - id="C2D message topic, does NOT decode '+' character", - ), - pytest.param( - 
"devices/fake%24device/modules/fake%23module/inputs/fake%25input/%24.exp=3237-07-19T23%3A06%3A40.0000000Z&%24.cid=fake%23corid&%24.mid=message%24id&%24.to=%2Fdevices%2Ffake%24device%2Fmodules%2Ffake%23module%2Finputs%2Ffake%25input&%24.ct=fake%23content%24type&%24.ce=utf-%24&iothub-ack=po%24itive&custom%2A=value%23&custom%26=value%24&custom%25=value%40", - { - "mid": "message$id", - "exp": "3237-07-19T23:06:40.0000000Z", - "cid": "fake#corid", - "ct": "fake#content$type", - "ce": "utf-$", - "iothub-ack": "po$itive", - }, - {"custom*": "value#", "custom&": "value$", "custom%": "value@"}, - id="Input message topic, Standard URL decoding", - ), - pytest.param( - "devices/fake+device/modules/fake+module/inputs/fake+input/%24.exp=3237-07-19T23%3A06%3A40.0000000Z&%24.cid=fake+corid&%24.mid=message+id&%24.to=%2Fdevices%2Ffake+device%2Fmodules%2Ffake+module%2Finputs%2Ffake+input&%24.ct=fake+content+type&%24.ce=utf-+&iothub-ack=posi+ive&custom+1=value+1&custom+2=value+2&custom+3=value+3", - { - "mid": "message+id", - "exp": "3237-07-19T23:06:40.0000000Z", - "cid": "fake+corid", - "ct": "fake+content+type", - "ce": "utf-+", - "iothub-ack": "posi+ive", - }, - {"custom+1": "value+1", "custom+2": "value+2", "custom+3": "value+3"}, - id="Input message topic, does NOT decode '+' character", - ), - ], - ) - def test_url_decode(self, topic, expected_system_properties, expected_custom_properties): - msg = Message("fake message") - mqtt_topic_iothub.extract_message_properties_from_topic(topic, msg) - - # Validate MANDATORY system properties - assert msg.message_id == expected_system_properties["mid"] - - # Validate OPTIONAL system properties - assert msg.correlation_id == expected_system_properties.get("cid", None) - assert msg.user_id == expected_system_properties.get("uid", None) - assert msg.content_type == expected_system_properties.get("ct", None) - assert msg.content_encoding == expected_system_properties.get("ce", None) - assert msg.expiry_time_utc == 
expected_system_properties.get("exp", None) - - # Validate custom properties - assert msg.custom_properties == expected_custom_properties - - @pytest.mark.it("Ignores certain properties in a C2D message topic, and does NOT extract them") - @pytest.mark.parametrize( - "topic", - [ - pytest.param( - "devices/fake_device/messages/devicebound/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmessages%2Fdevicebound", - id="$.to", - ), - ], - ) - def test_ignores_on_c2d(self, topic): - msg = Message("fake message") - mqtt_topic_iothub.extract_message_properties_from_topic(topic, msg) - assert msg.custom_properties == {} - - @pytest.mark.it( - "Ignores certain properties in an input message topic, and does NOT extract them" - ) - @pytest.mark.parametrize( - "topic", - [ - pytest.param( - "devices/fake_device/modules/fake_module/inputs/fake_input/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmodules%2Ffake_module%2Finputs%2Ffake_input", - id="$.to", - ), - ], - ) - def test_ignores_on_input_message(self, topic): - msg = Message("fake message") - mqtt_topic_iothub.extract_message_properties_from_topic(topic, msg) - assert msg.custom_properties == {} - - @pytest.mark.it( - "Raises a ValueError if the provided topic is not a c2d topic or an input message topic" - ) - @pytest.mark.parametrize( - "topic", - [ - pytest.param("not a topic", id="Not a topic"), - pytest.param( - "$iothub/twin/res/200/?$rid=d9d7ce4d-3be9-498b-abde-913b81b880e5", - id="Topic of wrong type", - ), - pytest.param( - "devices/fake_device/messages/devicebnd/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmessages%2Fdevicebound", - id="Malformed C2D topic", - ), - pytest.param( - "devices/fake_device/modules/fake_module/inutps/fake_input/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmodules%2Ffake_module%2Finputs%2Ffake_input", - id="Malformed input message topic", - ), - ], - ) - 
def test_bad_topic(self, topic): - msg = Message("fake message") - with pytest.raises(ValueError): - mqtt_topic_iothub.extract_message_properties_from_topic(topic, msg) - - @pytest.mark.it("Extracts system and custom properties without values") - @pytest.mark.parametrize( - "topic, extracted_system_properties, extracted_custom_properties", - [ - pytest.param( - "devices/fakedevice/messages/devicebound/topic=%2Fsubscriptions%2FresourceGroups&subject=%2FgraphInstances&dataVersion=1.0&%24.cdid=fakecdid&%24.mid=e32c2285-668e-4161-a236-9f5f6b90362c&%24.cid&%24.uid", - {"mid": "e32c2285-668e-4161-a236-9f5f6b90362c", "cid": None, "uid": None}, - { - "topic": "/subscriptions/resourceGroups", - "subject": "/graphInstances", - "dataVersion": "1.0", - "$.cdid": "fakecdid", - }, - id="C2D topic with some system properties not having values", - ), - pytest.param( - "devices/fakedevice/modules/fakemodule/inputs/fakeinput/topic=%2Fsubscriptions%2FresourceGroups&subject=%2FgraphInstances&dataVersion=1.0&%24.cdid=fakecdid&%24.mid=e32c2285-668e-4161-a236-9f5f6b90362c&%24.cid&%24.uid", - {"mid": "e32c2285-668e-4161-a236-9f5f6b90362c", "cid": None, "uid": None}, - { - "topic": "/subscriptions/resourceGroups", - "subject": "/graphInstances", - "dataVersion": "1.0", - "$.cdid": "fakecdid", - }, - id="Input message topic with some system properties not having values", - ), - pytest.param( - "devices/fakedevice/messages/devicebound/topic=%2Fsubscriptions%2FresourceGroups&subject=%2FgraphInstances&dataVersion&%24.cdid=fakecdid&%24.mid=e32c2285-668e-4161-a236-9f5f6b90362c&%24.cid=fakecorrid&%24.uid=harrypotter&classname", - { - "mid": "e32c2285-668e-4161-a236-9f5f6b90362c", - "cid": "fakecorrid", - "uid": "harrypotter", - }, - { - "topic": "/subscriptions/resourceGroups", - "subject": "/graphInstances", - "$.cdid": "fakecdid", - "classname": None, - "dataVersion": None, - }, - id="C2D topic with some custom properties not having values", - ), - pytest.param( - 
"devices/fakedevice/modules/fakemodule/inputs/fakeinput/topic=%2Fsubscriptions%2FresourceGroups&subject=%2FgraphInstances&dataVersion&%24.cdid=fakecdid&%24.mid=e32c2285-668e-4161-a236-9f5f6b90362c&%24.cid=fakecorrid&%24.uid=harrypotter&classname", - { - "mid": "e32c2285-668e-4161-a236-9f5f6b90362c", - "cid": "fakecorrid", - "uid": "harrypotter", - }, - { - "topic": "/subscriptions/resourceGroups", - "subject": "/graphInstances", - "$.cdid": "fakecdid", - "classname": None, - "dataVersion": None, - }, - id="Input message topic with some custom properties not having values", - ), - pytest.param( - "devices/fakedevice/messages/devicebound/topic=%2Fsubscriptions%2FresourceGroups&subject=%2FgraphInstances&dataVersion&%24.cdid=fakecdid&%24.mid=e32c2285-668e-4161-a236-9f5f6b90362c&%24.cid&%24.uid&classname", - {"mid": "e32c2285-668e-4161-a236-9f5f6b90362c", "cid": None, "uid": None}, - { - "topic": "/subscriptions/resourceGroups", - "subject": "/graphInstances", - "$.cdid": "fakecdid", - "classname": None, - "dataVersion": None, - }, - id="C2D topic with some system properties and some custom not having values", - ), - pytest.param( - "devices/fakedevice/modules/fakemodule/inputs/fakeinput/topic=%2Fsubscriptions%2FresourceGroups&subject=%2FgraphInstances&dataVersion&%24.cdid=fakecdid&%24.mid=e32c2285-668e-4161-a236-9f5f6b90362c&%24.cid&%24.uid&classname", - {"mid": "e32c2285-668e-4161-a236-9f5f6b90362c", "cid": None, "uid": None}, - { - "topic": "/subscriptions/resourceGroups", - "subject": "/graphInstances", - "$.cdid": "fakecdid", - "classname": None, - "dataVersion": None, - }, - id="Input message topic with some system properties and some custom not having values", - ), - ], - ) - def test_receive_topic_without_values( - self, topic, extracted_system_properties, extracted_custom_properties - ): - msg = Message("fake message") - mqtt_topic_iothub.extract_message_properties_from_topic(topic, msg) - - # Validate system properties received - assert msg.message_id == 
extracted_system_properties["mid"] - assert msg.correlation_id == extracted_system_properties["cid"] - assert msg.user_id == extracted_system_properties["uid"] - - # Validate system properties NOT received - assert msg.content_type == extracted_system_properties.get("ct", None) - assert msg.content_encoding == extracted_system_properties.get("ce", None) - assert msg.expiry_time_utc == extracted_system_properties.get("exp", None) - - # Validate custom properties - assert msg.custom_properties == extracted_custom_properties - - @pytest.mark.it("Extracts system and custom properties with empty string values") - @pytest.mark.parametrize( - "topic, extracted_system_properties, extracted_custom_properties", - [ - pytest.param( - "devices/fakedevice/messages/devicebound/topic=%2Fsubscriptions%2FresourceGroups&subject=%2FgraphInstances&dataVersion=1.0&%24.cdid=fakecdid&%24.mid=e32c2285-668e-4161-a236-9f5f6b90362c&%24.cid=&%24.uid=", - {"mid": "e32c2285-668e-4161-a236-9f5f6b90362c", "cid": "", "uid": ""}, - { - "topic": "/subscriptions/resourceGroups", - "subject": "/graphInstances", - "dataVersion": "1.0", - "$.cdid": "fakecdid", - }, - id="C2D topic with some system properties not having values", - ), - pytest.param( - "devices/fakedevice/modules/fakemodule/inputs/fakeinput/topic=%2Fsubscriptions%2FresourceGroups&subject=%2FgraphInstances&dataVersion=1.0&%24.cdid=fakecdid&%24.mid=e32c2285-668e-4161-a236-9f5f6b90362c&%24.cid=&%24.uid=", - {"mid": "e32c2285-668e-4161-a236-9f5f6b90362c", "cid": "", "uid": ""}, - { - "topic": "/subscriptions/resourceGroups", - "subject": "/graphInstances", - "dataVersion": "1.0", - "$.cdid": "fakecdid", - }, - id="Input message topic with some system properties not having values", - ), - pytest.param( - "devices/fakedevice/messages/devicebound/topic=%2Fsubscriptions%2FresourceGroups&subject=%2FgraphInstances&dataVersion=&%24.cdid=fakecdid&%24.mid=e32c2285-668e-4161-a236-9f5f6b90362c&%24.cid=fakecorrid&%24.uid=harrypotter&classname=", - { - 
"mid": "e32c2285-668e-4161-a236-9f5f6b90362c", - "cid": "fakecorrid", - "uid": "harrypotter", - }, - { - "topic": "/subscriptions/resourceGroups", - "subject": "/graphInstances", - "$.cdid": "fakecdid", - "dataVersion": "", - "classname": "", - }, - id="C2D topic with some custom properties not having values", - ), - pytest.param( - "devices/fakedevice/modules/fakemodule/inputs/fakeinput/topic=%2Fsubscriptions%2FresourceGroups&subject=%2FgraphInstances&dataVersion=&%24.cdid=fakecdid&%24.mid=e32c2285-668e-4161-a236-9f5f6b90362c&%24.cid=fakecorrid&%24.uid=harrypotter&classname=", - { - "mid": "e32c2285-668e-4161-a236-9f5f6b90362c", - "cid": "fakecorrid", - "uid": "harrypotter", - }, - { - "topic": "/subscriptions/resourceGroups", - "subject": "/graphInstances", - "$.cdid": "fakecdid", - "dataVersion": "", - "classname": "", - }, - id="Input message topic with some custom properties not having values", - ), - pytest.param( - "devices/fakedevice/messages/devicebound/topic=%2Fsubscriptions%2FresourceGroups&subject=%2FgraphInstances&dataVersion=&%24.cdid=fakecdid&%24.mid=e32c2285-668e-4161-a236-9f5f6b90362c&%24.cid=&%24.uid=&classname=", - {"mid": "e32c2285-668e-4161-a236-9f5f6b90362c", "cid": "", "uid": ""}, - { - "topic": "/subscriptions/resourceGroups", - "subject": "/graphInstances", - "$.cdid": "fakecdid", - "dataVersion": "", - "classname": "", - }, - id="C2D topic with some system properties and some custom not having values", - ), - pytest.param( - "devices/fakedevice/modules/fakemodule/inputs/fakeinput/topic=%2Fsubscriptions%2FresourceGroups&subject=%2FgraphInstances&dataVersion=&%24.cdid=fakecdid&%24.mid=e32c2285-668e-4161-a236-9f5f6b90362c&%24.cid=&%24.uid=&classname=", - {"mid": "e32c2285-668e-4161-a236-9f5f6b90362c", "cid": "", "uid": ""}, - { - "topic": "/subscriptions/resourceGroups", - "subject": "/graphInstances", - "$.cdid": "fakecdid", - "dataVersion": "", - "classname": "", - }, - id="Input message topic with some system properties and some custom not 
having values", - ), - ], - ) - def test_receive_topic_with_empty_values( - self, topic, extracted_system_properties, extracted_custom_properties - ): - msg = Message("fake message") - mqtt_topic_iothub.extract_message_properties_from_topic(topic, msg) - - # Validate system properties received - assert msg.message_id == extracted_system_properties["mid"] - assert msg.correlation_id == extracted_system_properties["cid"] - assert msg.user_id == extracted_system_properties["uid"] - - # Validate system properties NOT received - assert msg.content_type == extracted_system_properties.get("ct", None) - assert msg.content_encoding == extracted_system_properties.get("ce", None) - assert msg.expiry_time_utc == extracted_system_properties.get("exp", None) - - # Validate custom properties - assert msg.custom_properties == extracted_custom_properties - - -@pytest.mark.describe(".encode_message_properties_in_topic()") -class TestEncodeMessagePropertiesInTopic(object): - def create_message(self, system_properties, custom_properties): - m = Message("payload") - m.message_id = system_properties.get("mid") - m.correlation_id = system_properties.get("cid") - m.user_id = system_properties.get("uid") - m.output_name = system_properties.get("on") - m.content_encoding = system_properties.get("ce") - m.content_type = system_properties.get("ct") - m.expiry_time_utc = system_properties.get("exp") - if system_properties.get("ifid"): - m.set_as_security_message() - m.custom_properties = custom_properties - return m - - @pytest.fixture(params=["C2D Message", "Input Message"]) - def message_topic(self, request): - if request.param == "C2D Message": - return "devices/fake_device/messages/events/" - else: - return "devices/fake_device/modules/fake_module/messages/events/" - - @pytest.mark.it( - "Returns a new version of the given topic string that contains message properties from the given message" - ) - @pytest.mark.parametrize( - "message_system_properties, message_custom_properties, 
expected_encoding", - [ - pytest.param({}, {}, "", id="No properties"), - pytest.param( - {"mid": "1234", "ce": "utf-8"}, - {}, - "%24.mid=1234&%24.ce=utf-8", - id="Some System Properties", - ), - pytest.param( - { - "mid": "1234", - "cid": "5678", - "uid": "userid", - "on": "output", - "ce": "utf-8", - "ct": "type", - "exp": datetime.datetime(2019, 2, 2), - "ifid": True, - }, - {}, - "%24.on=output&%24.mid=1234&%24.cid=5678&%24.uid=userid&%24.ct=type&%24.ce=utf-8&%24.ifid=urn%3Aazureiot%3ASecurity%3ASecurityAgent%3A1&%24.exp=2019-02-02T00%3A00%3A00", - id="All System Properties", - ), - pytest.param( - {}, - {"custom1": "value1", "custom2": "value2", "custom3": "value3"}, - "custom1=value1&custom2=value2&custom3=value3", - id="Custom Properties ONLY", - ), - pytest.param( - { - "mid": "1234", - "cid": "5678", - "uid": "userid", - "on": "output", - "ce": "utf-8", - "ct": "type", - "exp": datetime.datetime(2019, 2, 2), - "ifid": True, - }, - {"custom1": "value1", "custom2": "value2", "custom3": "value3"}, - "%24.on=output&%24.mid=1234&%24.cid=5678&%24.uid=userid&%24.ct=type&%24.ce=utf-8&%24.ifid=urn%3Aazureiot%3ASecurity%3ASecurityAgent%3A1&%24.exp=2019-02-02T00%3A00%3A00&custom1=value1&custom2=value2&custom3=value3", - id="System Properties AND Custom Properties", - ), - ], - ) - def test_encodes_properties( - self, message_topic, message_system_properties, message_custom_properties, expected_encoding - ): - message = self.create_message(message_system_properties, message_custom_properties) - encoded_topic = mqtt_topic_iothub.encode_message_properties_in_topic(message, message_topic) - - assert encoded_topic.startswith(message_topic) - encoding = encoded_topic.split(message_topic)[1] - assert encoding == expected_encoding - - @pytest.mark.it("URL encodes message properties when adding them to the topic") - @pytest.mark.parametrize( - "message_system_properties, message_custom_properties, expected_encoding", - [ - pytest.param( - { - "mid": "message#id", - "cid": 
"correlation#id", - "uid": "user#id", - "on": "some#output", - "ce": "utf-#", - "ct": "fake#type", - "exp": datetime.datetime(2019, 2, 2), - "ifid": True, - }, - {"custom#1": "value#1", "custom#2": "value#2", "custom#3": "value#3"}, - "%24.on=some%23output&%24.mid=message%23id&%24.cid=correlation%23id&%24.uid=user%23id&%24.ct=fake%23type&%24.ce=utf-%23&%24.ifid=urn%3Aazureiot%3ASecurity%3ASecurityAgent%3A1&%24.exp=2019-02-02T00%3A00%3A00&custom%231=value%231&custom%232=value%232&custom%233=value%233", - id="Standard URL Encoding", - ), - pytest.param( - { - "mid": "message id", - "cid": "correlation id", - "uid": "user id", - "on": "some output", - "ce": "utf- ", - "ct": "fake type", - "exp": datetime.datetime(2019, 2, 2), - "ifid": True, - }, - {"custom 1": "value 1", "custom 2": "value 2", "custom 3": "value 3"}, - "%24.on=some%20output&%24.mid=message%20id&%24.cid=correlation%20id&%24.uid=user%20id&%24.ct=fake%20type&%24.ce=utf-%20&%24.ifid=urn%3Aazureiot%3ASecurity%3ASecurityAgent%3A1&%24.exp=2019-02-02T00%3A00%3A00&custom%201=value%201&custom%202=value%202&custom%203=value%203", - id="URL Encoding of ' ' character", - ), - pytest.param( - { - "mid": "message/id", - "cid": "correlation/id", - "uid": "user/id", - "on": "some/output", - "ce": "utf-/", - "ct": "fake/type", - "exp": datetime.datetime(2019, 2, 2), - "ifid": True, - }, - {"custom/1": "value/1", "custom/2": "value/2", "custom/3": "value/3"}, - "%24.on=some%2Foutput&%24.mid=message%2Fid&%24.cid=correlation%2Fid&%24.uid=user%2Fid&%24.ct=fake%2Ftype&%24.ce=utf-%2F&%24.ifid=urn%3Aazureiot%3ASecurity%3ASecurityAgent%3A1&%24.exp=2019-02-02T00%3A00%3A00&custom%2F1=value%2F1&custom%2F2=value%2F2&custom%2F3=value%2F3", - id="URL Encoding of '/' character", - ), - ], - ) - def test_url_encodes( - self, message_topic, message_system_properties, message_custom_properties, expected_encoding - ): - message = self.create_message(message_system_properties, message_custom_properties) - encoded_topic = 
mqtt_topic_iothub.encode_message_properties_in_topic(message, message_topic) - - assert encoded_topic.startswith(message_topic) - encoding = encoded_topic.split(message_topic)[1] - assert encoding == expected_encoding - - @pytest.mark.it("String converts message properties when adding them to the topic") - def test_str_conversion(self, message_topic): - system_properties = {"mid": 1234, "cid": 5678, "uid": 4000, "on": 2222, "ce": 8, "ct": 12} - custom_properties = {1: 23, 47: 245, 3000: 9458} - expected_encoding = "%24.on=2222&%24.mid=1234&%24.cid=5678&%24.uid=4000&%24.ct=12&%24.ce=8&1=23&3000=9458&47=245" - message = self.create_message(system_properties, custom_properties) - encoded_topic = mqtt_topic_iothub.encode_message_properties_in_topic(message, message_topic) - - assert encoded_topic.startswith(message_topic) - encoding = encoded_topic.split(message_topic)[1] - assert encoding == expected_encoding - - @pytest.mark.it( - "Raises ValueError if duplicate keys exist in custom properties due to string conversion" - ) - def test_duplicate_keys(self, message_topic): - system_properties = {} - custom_properties = {1: "val1", "1": "val2"} - message = self.create_message(system_properties, custom_properties) - - with pytest.raises(ValueError): - mqtt_topic_iothub.encode_message_properties_in_topic(message, message_topic) diff --git a/tests/unit/iothub/pipeline/test_pipeline_events_iothub.py b/tests/unit/iothub/pipeline/test_pipeline_events_iothub.py deleted file mode 100644 index 8c5b370f3..000000000 --- a/tests/unit/iothub/pipeline/test_pipeline_events_iothub.py +++ /dev/null @@ -1,37 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import sys -import logging -from azure.iot.device.iothub.pipeline import pipeline_events_iothub -from tests.unit.common.pipeline import pipeline_event_test - -logging.basicConfig(level=logging.DEBUG) -this_module = sys.modules[__name__] - -pipeline_event_test.add_event_test( - cls=pipeline_events_iothub.C2DMessageEvent, - module=this_module, - positional_arguments=["message"], - keyword_arguments={}, -) -pipeline_event_test.add_event_test( - cls=pipeline_events_iothub.InputMessageEvent, - module=this_module, - positional_arguments=["message"], - keyword_arguments={}, -) -pipeline_event_test.add_event_test( - cls=pipeline_events_iothub.MethodRequestEvent, - module=this_module, - positional_arguments=["method_request"], - keyword_arguments={}, -) -pipeline_event_test.add_event_test( - cls=pipeline_events_iothub.TwinDesiredPropertiesPatchEvent, - module=this_module, - positional_arguments=["patch"], - keyword_arguments={}, -) diff --git a/tests/unit/iothub/pipeline/test_pipeline_ops_iothub.py b/tests/unit/iothub/pipeline/test_pipeline_ops_iothub.py deleted file mode 100644 index d2ddb5b30..000000000 --- a/tests/unit/iothub/pipeline/test_pipeline_ops_iothub.py +++ /dev/null @@ -1,148 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import pytest -import sys -import logging -from azure.iot.device.iothub.pipeline import pipeline_ops_iothub -from tests.unit.common.pipeline import pipeline_ops_test - -logging.basicConfig(level=logging.DEBUG) -this_module = sys.modules[__name__] -pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") - - -class SendD2CMessageOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_iothub.SendD2CMessageOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"message": mocker.MagicMock(), "callback": mocker.MagicMock()} - return kwargs - - -class SendD2CMessageOperationInstantiationTests(SendD2CMessageOperationTestConfig): - @pytest.mark.it("Initializes 'message' attribute with the provided 'message' parameter") - def test_message(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.message is init_kwargs["message"] - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_iothub.SendD2CMessageOperation, - op_test_config_class=SendD2CMessageOperationTestConfig, - extended_op_instantiation_test_class=SendD2CMessageOperationInstantiationTests, -) - - -class SendOutputMessageOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_iothub.SendOutputMessageOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"message": mocker.MagicMock(), "callback": mocker.MagicMock()} - return kwargs - - -class SendOutputMessageOperationInstantiationTests(SendOutputMessageOperationTestConfig): - @pytest.mark.it("Initializes 'message' attribute with the provided 'message' parameter") - def test_message(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.message is init_kwargs["message"] - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - 
op_class_under_test=pipeline_ops_iothub.SendOutputMessageOperation, - op_test_config_class=SendOutputMessageOperationTestConfig, - extended_op_instantiation_test_class=SendOutputMessageOperationInstantiationTests, -) - - -class SendMethodResponseOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_iothub.SendMethodResponseOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"method_response": mocker.MagicMock(), "callback": mocker.MagicMock()} - return kwargs - - -class SendMethodResponseOperationInstantiationTests(SendMethodResponseOperationTestConfig): - @pytest.mark.it( - "Initializes 'method_response' attribute with the provided 'method_response' parameter" - ) - def test_method_response(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.method_response is init_kwargs["method_response"] - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_iothub.SendMethodResponseOperation, - op_test_config_class=SendMethodResponseOperationTestConfig, - extended_op_instantiation_test_class=SendMethodResponseOperationInstantiationTests, -) - - -class GetTwinOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_iothub.GetTwinOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"callback": mocker.MagicMock()} - return kwargs - - -class GetTwinOperationInstantiationTests(GetTwinOperationTestConfig): - @pytest.mark.it("Initializes 'twin' attribute as None") - def test_twin(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.twin is None - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_iothub.GetTwinOperation, - op_test_config_class=GetTwinOperationTestConfig, - extended_op_instantiation_test_class=GetTwinOperationInstantiationTests, -) - - -class PatchTwinReportedPropertiesOperationTestConfig(object): - 
@pytest.fixture - def cls_type(self): - return pipeline_ops_iothub.PatchTwinReportedPropertiesOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"patch": {"some": "patch"}, "callback": mocker.MagicMock()} - return kwargs - - -class PatchTwinReportedPropertiesOperationInstantiationTests( - PatchTwinReportedPropertiesOperationTestConfig -): - @pytest.mark.it("Initializes 'patch' attribute with the provided 'patch' parameter") - def test_patch(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.patch is init_kwargs["patch"] - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_iothub.PatchTwinReportedPropertiesOperation, - op_test_config_class=PatchTwinReportedPropertiesOperationTestConfig, - extended_op_instantiation_test_class=PatchTwinReportedPropertiesOperationInstantiationTests, -) diff --git a/tests/unit/iothub/pipeline/test_pipeline_ops_iothub_http.py b/tests/unit/iothub/pipeline/test_pipeline_ops_iothub_http.py deleted file mode 100644 index f7115003c..000000000 --- a/tests/unit/iothub/pipeline/test_pipeline_ops_iothub_http.py +++ /dev/null @@ -1,149 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import pytest -import sys -import logging -from azure.iot.device.iothub.pipeline import pipeline_ops_iothub_http -from tests.unit.common.pipeline import pipeline_ops_test - -logging.basicConfig(level=logging.DEBUG) -this_module = sys.modules[__name__] -pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") - -fake_device_id = "__fake_device_id__" -fake_module_id = "__fake_module_id__" - - -class MethodInvokeOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_iothub_http.MethodInvokeOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = { - "target_device_id": fake_device_id, - "target_module_id": fake_module_id, - "method_params": mocker.MagicMock(), - "callback": mocker.MagicMock(), - } - return kwargs - - -class MethodInvokeOperationInstantiationTests(MethodInvokeOperationTestConfig): - @pytest.mark.it("Initializes 'device_id' attribute with the provided 'device_id' parameter") - def test_device_id(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.target_device_id is init_kwargs["target_device_id"] - - @pytest.mark.it("Initializes 'module_id' attribute with the provided 'module_id' parameter") - def test_module_id(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.target_module_id is init_kwargs["target_module_id"] - - @pytest.mark.it( - "Initializes 'method_params' attribute with the provided 'method_params' parameter" - ) - def test_method_params(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.method_params is init_kwargs["method_params"] - - @pytest.mark.it("Initializes 'method_response' attribute as None") - def test_method_response(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.method_response is None - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - 
op_class_under_test=pipeline_ops_iothub_http.MethodInvokeOperation, - op_test_config_class=MethodInvokeOperationTestConfig, - extended_op_instantiation_test_class=MethodInvokeOperationInstantiationTests, -) - - -class GetStorageInfoOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_iothub_http.GetStorageInfoOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"blob_name": "__fake_blob_name__", "callback": mocker.MagicMock()} - return kwargs - - -class GetStorageInfoOperationInstantiationTests(GetStorageInfoOperationTestConfig): - @pytest.mark.it("Initializes 'blob_name' attribute with the provided 'blob_name' parameter") - def test_blob_name(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.blob_name is init_kwargs["blob_name"] - - @pytest.mark.it("Initializes 'storage_info' attribute as None") - def test_storage_info(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.storage_info is None - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_iothub_http.GetStorageInfoOperation, - op_test_config_class=GetStorageInfoOperationTestConfig, - extended_op_instantiation_test_class=GetStorageInfoOperationInstantiationTests, -) - - -class NotifyBlobUploadStatusOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_iothub_http.NotifyBlobUploadStatusOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = { - "correlation_id": "__fake_correlation_id__", - "is_success": "__fake_is_success__", - "status_code": "__fake_status_code__", - "status_description": "__fake_status_description__", - "callback": mocker.MagicMock(), - } - return kwargs - - -class NotifyBlobUploadStatusOperationInstantiationTests(NotifyBlobUploadStatusOperationTestConfig): - @pytest.mark.it( - "Initializes 'correlation_id' attribute with the provided 'correlation_id' parameter" - ) - def 
test_correlation_id(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.correlation_id is init_kwargs["correlation_id"] - - @pytest.mark.it("Initializes 'is_success' attribute with the provided 'is_success' parameter") - def test_is_success(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.is_success is init_kwargs["is_success"] - - @pytest.mark.it( - "Initializes 'request_status_code' attribute with the provided 'status_code' parameter" - ) - def test_request_status_code(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.request_status_code is init_kwargs["status_code"] - - @pytest.mark.it( - "Initializes 'status_description' attribute with the provided 'status_description' parameter" - ) - def test_status_description(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.status_description is init_kwargs["status_description"] - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_iothub_http.NotifyBlobUploadStatusOperation, - op_test_config_class=NotifyBlobUploadStatusOperationTestConfig, - extended_op_instantiation_test_class=NotifyBlobUploadStatusOperationInstantiationTests, -) diff --git a/tests/unit/iothub/pipeline/test_pipeline_stages_iothub.py b/tests/unit/iothub/pipeline/test_pipeline_stages_iothub.py deleted file mode 100644 index 5de141f61..000000000 --- a/tests/unit/iothub/pipeline/test_pipeline_stages_iothub.py +++ /dev/null @@ -1,1130 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import json -import logging -import pytest -import sys -from azure.iot.device.exceptions import ServiceError -from azure.iot.device.iothub.pipeline import ( - pipeline_events_iothub, - pipeline_ops_iothub, - pipeline_stages_iothub, - constant as pipeline_constants, -) -from azure.iot.device.common.pipeline import ( - pipeline_ops_base, - pipeline_events_base, -) -from tests.unit.common.pipeline.helpers import StageRunOpTestBase, StageHandlePipelineEventTestBase -from tests.unit.common.pipeline import pipeline_stage_test - -logging.basicConfig(level=logging.DEBUG) -this_module = sys.modules[__name__] -pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") - - -fake_device_id = "__fake_device_id__" -fake_module_id = "__fake_module_id__" -fake_hostname = "__fake_hostname__" -fake_gateway_hostname = "__fake_gateway_hostname__" -fake_server_verification_cert = "__fake_server_verification_cert__" -fake_sas_token = "__fake_sas_token__" -fake_symmetric_key = "Zm9vYmFy" -fake_x509_cert_file = "fake_certificate_file" -fake_x509_cert_key_file = "fake_certificate_key_file" -fake_pass_phrase = "fake_pass_phrase" - - -################### -# COMMON FIXTURES # -################### - - -@pytest.fixture(params=[True, False], ids=["With error", "No error"]) -def op_error(request, arbitrary_exception): - if request.param: - return arbitrary_exception - else: - return None - - -######################################### -# ENSURE DESIRED PROPERTIES STAGE STAGE # -######################################### - - -class EnsureDesiredPropertiesStageTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_iothub.EnsureDesiredPropertiesStage - - @pytest.fixture - def init_kwargs(self): - return {} - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - # Tests are going to assume 
ensure_desired_properties is true, as most tests will need - # it to be true - stage.nucleus.pipeline_configuration.ensure_desired_properties = True - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - -class EnsureDesiredPropertiesStageInstantiationTests(EnsureDesiredPropertiesStageTestConfig): - @pytest.mark.it("Initializes 'last_version_seen' None") - def test_last_version_seen(self, init_kwargs): - stage = pipeline_stages_iothub.EnsureDesiredPropertiesStage(**init_kwargs) - assert stage.last_version_seen is None - - @pytest.mark.it("Initializes 'pending_get_request' None") - def test_pending_get_request(self, init_kwargs): - stage = pipeline_stages_iothub.EnsureDesiredPropertiesStage(**init_kwargs) - assert stage.pending_get_request is None - - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_iothub.EnsureDesiredPropertiesStage, - stage_test_config_class=EnsureDesiredPropertiesStageTestConfig, - extended_stage_instantiation_test_class=EnsureDesiredPropertiesStageInstantiationTests, -) - - -@pytest.mark.describe( - "EnsureDesiredPropertiesStage - .run_op() -- Called with EnableFeatureOperation (ensure_desired_properties enabled)" -) -class TestEnsureDesiredPropertiesStageRunOpWithEnableFeatureOperationWithEnsureDesiredPropertiesEnabled( - StageRunOpTestBase, EnsureDesiredPropertiesStageTestConfig -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.EnableFeatureOperation( - feature_name="fake_feature_name", callback=mocker.MagicMock() - ) - - @pytest.mark.it("Sets `last_version_seen` to -1 if `op.feature_name` is 'twin_patches'") - def test_sets_last_version_seen(self, mocker, stage, op): - op.feature_name = pipeline_constants.TWIN_PATCHES - assert stage.nucleus.pipeline_configuration.ensure_desired_properties - assert stage.last_version_seen is None - stage.run_op(op) - - 
assert stage.last_version_seen == -1 - - @pytest.mark.parametrize( - "feature_name", - [ - pipeline_constants.C2D_MSG, - pipeline_constants.INPUT_MSG, - pipeline_constants.METHODS, - pipeline_constants.TWIN, - ], - ) - @pytest.mark.it( - "Does not change `last_version_seen` if `op.feature_name` is not 'twin_patches'" - ) - def test_doesnt_set_last_version_seen(self, mocker, stage, op, feature_name): - op.feature_name = feature_name - assert stage.nucleus.pipeline_configuration.ensure_desired_properties - stage.last_version_seen = mocker.MagicMock() - - old_value = stage.last_version_seen - stage.run_op(op) - - assert stage.last_version_seen == old_value - - @pytest.mark.parametrize( - "feature_name", - [ - pipeline_constants.C2D_MSG, - pipeline_constants.INPUT_MSG, - pipeline_constants.METHODS, - pipeline_constants.TWIN, - pipeline_constants.TWIN_PATCHES, - ], - ) - @pytest.mark.it( - "Sends the EnableFeatureOperation op to the next stage for all valid `op.feature_name` values" - ) - def test_passes_all_other_features_down(self, mocker, stage, op, feature_name): - op.feature_name = feature_name - assert stage.nucleus.pipeline_configuration.ensure_desired_properties - - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - -@pytest.mark.describe( - "EnsureDesiredPropertiesStage - .run_op() -- Called with other arbitrary operation (ensure_desired_properties enabled)" -) -class TestEnsureDesiredPropertiesStageRunOpWithArbitraryOperationWithEnsureDesiredPropertiesEnabled( - StageRunOpTestBase, EnsureDesiredPropertiesStageTestConfig -): - @pytest.fixture - def op(self, arbitrary_op): - return arbitrary_op - - @pytest.mark.it("Sends the operation down the pipeline") - def test_sends_op_down(self, mocker, stage, op): - assert stage.nucleus.pipeline_configuration.ensure_desired_properties - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == 
mocker.call(op) - - -@pytest.mark.describe( - "EnsureDesiredPropertiesStage - .run_op() -- Called with EnableFeatureOperation (ensure_desired_properties disabled)" -) -class TestEnsureDesiredPropertiesStageRunOpWithEnableFeatureOperationWithEnsureDesiredPropertiesDisabled( - StageRunOpTestBase, EnsureDesiredPropertiesStageTestConfig -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.EnableFeatureOperation( - feature_name="fake_feature_name", callback=mocker.MagicMock() - ) - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus): - # Overriding the parent class - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.nucleus.pipeline_configuration.ensure_desired_properties = False - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - @pytest.mark.parametrize( - "feature_name", - [ - pipeline_constants.C2D_MSG, - pipeline_constants.INPUT_MSG, - pipeline_constants.METHODS, - pipeline_constants.TWIN, - pipeline_constants.TWIN_PATCHES, - ], - ) - @pytest.mark.it( - "Sends the EnableFeatureOperation op to the next stage for all valid `op.feature_name` values" - ) - def test_passes_all_other_features_down(self, mocker, stage, op, feature_name): - op.feature_name = feature_name - assert stage.nucleus.pipeline_configuration.ensure_desired_properties is False - - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - @pytest.mark.it( - "Does not change `last_version_seen`, even if `op.feature_name` is twin_patches'" - ) - def test_doesnt_set_last_version_seen(self, mocker, stage, op): - op.feature_name = pipeline_constants.TWIN_PATCHES - assert stage.nucleus.pipeline_configuration.ensure_desired_properties is False - stage.last_version_seen = mocker.MagicMock() - - old_value = stage.last_version_seen - stage.run_op(op) - - assert stage.last_version_seen 
== old_value - - -@pytest.mark.describe( - "EnsureDesiredPropertiesStage - .run_op() -- Called with other arbitrary operation (ensure_desired_properties disabled)" -) -class TestEnsureDesiredPropertiesStageRunOpWithArbitraryOperationWithEnsureDesiredPropertiesDisabled( - StageRunOpTestBase, EnsureDesiredPropertiesStageTestConfig -): - @pytest.fixture - def op(self, arbitrary_op): - return arbitrary_op - - @pytest.mark.it("Sends the operation down the pipeline") - def test_sends_op_down(self, mocker, stage, op): - stage.nucleus.pipeline_configuration.ensure_desired_properties = False - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - -@pytest.mark.describe( - "EnsureDesiredPropertiesStage - .handle_pipeline_event() -- Called with ConnectedEvent (ensure_desired_properties enabled)" -) -class TestEnsureDesiredPropertiesStageWhenConnectedEventReceivedWithEnsureDesiredPropertiesEnabled( - EnsureDesiredPropertiesStageTestConfig, StageHandlePipelineEventTestBase -): - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.nucleus.pipeline_configuration.ensure_desired_properties = True - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - @pytest.fixture - def event(self): - return pipeline_events_base.ConnectedEvent() - - @pytest.mark.it( - "Sends a GetTwinOperation if last_version_seen is set and there is no pending GetTwinOperation" - ) - def test_last_version_seen_no_pending(self, mocker, stage, event): - assert stage.nucleus.pipeline_configuration.ensure_desired_properties - stage.last_version_seen = mocker.MagicMock() - stage.pending_get_request = None - - stage.handle_pipeline_event(event) - - assert stage.send_op_down.call_count == 1 - assert isinstance(stage.send_op_down.call_args[0][0], 
pipeline_ops_iothub.GetTwinOperation) - - @pytest.mark.it( - "Does not send a GetTwinOperation if last version seen is set and there is already a pending GetTwinOperation" - ) - def test_last_version_seen_pending(self, mocker, stage, event): - assert stage.nucleus.pipeline_configuration.ensure_desired_properties - stage.last_version_seen = mocker.MagicMock() - stage.pending_get_request = mocker.MagicMock() - - stage.handle_pipeline_event(event) - - assert stage.send_op_down.call_count == 0 - - @pytest.mark.it( - "Does not send a GetTwinOperation if last_version_seen is not set and there is no pending GetTwinOperation" - ) - def test_no_last_version_seen_no_pending(self, mocker, stage, event): - assert stage.nucleus.pipeline_configuration.ensure_desired_properties - stage.last_version_seen = None - stage.pending_get_request = None - - stage.handle_pipeline_event(event) - - assert stage.send_op_down.call_count == 0 - - @pytest.mark.it( - "Does not send a GetTwinOperation if last version seen is not set and there is already a pending GetTwinOperation" - ) - def test_no_last_version_seen_pending(self, mocker, stage, event): - assert stage.nucleus.pipeline_configuration.ensure_desired_properties - stage.last_version_seen = None - stage.pending_get_request = mocker.MagicMock() - - stage.handle_pipeline_event(event) - - assert stage.send_op_down.call_count == 0 - - -@pytest.mark.describe( - "EnsureDesiredPropertiesStage - .handle_pipeline_event() -- Called with ConnectedEvent (ensure_desired_properties disabled)" -) -class TestEnsureDesiredPropertiesStageWhenConnectedEventReceivedWithEnsureDesiredPropertiesDisabled( - EnsureDesiredPropertiesStageTestConfig, StageHandlePipelineEventTestBase -): - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.nucleus.pipeline_configuration.ensure_desired_properties = False - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = 
mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - @pytest.fixture - def event(self): - return pipeline_events_base.ConnectedEvent() - - @pytest.mark.it("Does not send a GetTwinOperation if Ensure_Desired_Properties is disabled") - def test_no_get_twin_op(self, stage, event): - assert stage.nucleus.pipeline_configuration.ensure_desired_properties is False - stage.last_version_seen = None - - stage.handle_pipeline_event(event) - - assert stage.send_op_down.call_count == 0 - - -@pytest.mark.describe( - "EnsureDesiredPropertiesStage - .handle_pipeline_event() -- Called with TwinDesiredPropertiesPatchEvent (ensure_desired_properties enabled)" -) -class TestEnsureDesiredPropertiesStageWhenTwinDesiredPropertiesPatchEventReceivedWithEnsureDesiredPropertiesEnabled( - EnsureDesiredPropertiesStageTestConfig, StageHandlePipelineEventTestBase -): - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.nucleus.pipeline_configuration.ensure_desired_properties = True - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - @pytest.fixture - def version(self, mocker): - return mocker.MagicMock() - - @pytest.fixture - def event(self, version): - return pipeline_events_iothub.TwinDesiredPropertiesPatchEvent(patch={"$version": version}) - - @pytest.mark.it("Saves the `$version` attribute of the patch into `last_version_seen`") - def test_saves_the_last_version_seen(self, mocker, stage, event, version): - assert stage.nucleus.pipeline_configuration.ensure_desired_properties - stage.last_version_seen = mocker.MagicMock() - - stage.handle_pipeline_event(event) - - assert stage.last_version_seen == version - - @pytest.mark.it("Sends the event to the previous stage") - def test_sends_event_up(self, mocker, stage, event, version): - assert 
stage.nucleus.pipeline_configuration.ensure_desired_properties - stage.handle_pipeline_event(event) - - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args == mocker.call(event) - - -@pytest.mark.describe( - "EnsureDesiredPropertiesStage - .handle_pipeline_event() -- Called with TwinDesiredPropertiesPatchEvent (ensure_desired_properties disabled)" -) -class TestEnsureDesiredPropertiesStageWhenTwinDesiredPropertiesPatchEventReceivedWithEnsureDesiredPropertiesDisabled( - EnsureDesiredPropertiesStageTestConfig, StageHandlePipelineEventTestBase -): - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.nucleus.pipeline_configuration.ensure_desired_properties = False - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - @pytest.fixture - def version(self, mocker): - return mocker.MagicMock() - - @pytest.fixture - def event(self, version): - return pipeline_events_iothub.TwinDesiredPropertiesPatchEvent(patch={"$version": version}) - - @pytest.mark.it("Does not change `last_version_seen`") - def test_doesnt_save_the_last_version_seen(self, mocker, stage, event, version): - assert stage.nucleus.pipeline_configuration.ensure_desired_properties is False - stage.last_version_seen = mocker.MagicMock() - - old_version = stage.last_version_seen - stage.handle_pipeline_event(event) - - assert stage.last_version_seen == old_version - - @pytest.mark.it("Sends the event to the previous stage") - def test_sends_event_up(self, mocker, stage, event, version): - assert stage.nucleus.pipeline_configuration.ensure_desired_properties is False - stage.handle_pipeline_event(event) - - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args == mocker.call(event) - - -@pytest.mark.describe( - "EnsureDesiredPropertiesStage - .handle_pipeline_event() -- 
Called with other arbitrary event (ensure_desired_properties enabled)" -) -class TestEnsureDesiredPropertiesStageWhenArbitraryEventReceivedWithEnsureDesiredPropertiesEnabled( - EnsureDesiredPropertiesStageTestConfig, StageHandlePipelineEventTestBase -): - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.nucleus.pipeline_configuration.ensure_desired_properties = True - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - @pytest.fixture - def version(self, mocker): - return mocker.MagicMock() - - @pytest.fixture - def event(self, arbitrary_event): - return arbitrary_event - - @pytest.mark.it("Sends the event to the previous stage") - def test_sends_event_up(self, mocker, stage, event): - assert stage.nucleus.pipeline_configuration.ensure_desired_properties - stage.handle_pipeline_event(event) - - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args == mocker.call(event) - - -@pytest.mark.describe( - "EnsureDesiredPropertiesStage - .handle_pipeline_event() -- Called with other arbitrary event (ensure_desired_properties disabled)" -) -class TestEnsureDesiredPropertiesStageWhenArbitraryEventReceivedWithEnsureDesiredPropertiesDisabled( - EnsureDesiredPropertiesStageTestConfig, StageHandlePipelineEventTestBase -): - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.nucleus.pipeline_configuration.ensure_desired_properties = False - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - @pytest.fixture - def version(self, mocker): - return mocker.MagicMock() - - @pytest.fixture - def event(self, arbitrary_event): - return arbitrary_event - - @pytest.mark.it("Sends 
the event to the previous stage") - def test_sends_event_up(self, mocker, stage, event): - assert stage.nucleus.pipeline_configuration.ensure_desired_properties is False - stage.handle_pipeline_event(event) - - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args == mocker.call(event) - - -@pytest.mark.describe( - "EnsureDesiredPropertiesStage - OCCURRENCE: GetTwinOperation that was sent down by this stage completes and pipeline is connected" -) -class TestEnsureDesiredPropertiesStageWhenGetTwinOperationCompletesConnected( - EnsureDesiredPropertiesStageTestConfig -): - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.nucleus.pipeline_configuration.ensure_desired_properties = True - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - @pytest.fixture - def get_twin_op(self, stage): - stage.last_version_seen = -1 - stage.handle_pipeline_event(pipeline_events_base.ConnectedEvent()) - - get_twin_op = stage.send_op_down.call_args[0][0] - assert isinstance(get_twin_op, pipeline_ops_iothub.GetTwinOperation) - - stage.send_op_down.reset_mock() - stage.send_event_up.reset_mock() - - return get_twin_op - - @pytest.fixture - def new_version(self): - return 1234 - - @pytest.fixture - def new_twin(self, new_version): - return {"desired": {"$version": new_version}, "reported": {}} - - @pytest.mark.it("Does not send a new GetTwinOperation if the op completes with success") - def test_does_not_send_new_get_twin_operation_on_success( - self, stage, get_twin_op, new_twin, pipeline_connected_mock - ): - pipeline_connected_mock.return_value = True - get_twin_op.twin = new_twin - get_twin_op.complete() - - assert stage.send_op_down.call_count == 0 - - @pytest.mark.it("Sets `pending_get_request` to None if the op completes with success") - def 
test_sets_pending_request_to_none_on_success( - self, mocker, stage, get_twin_op, new_twin, pipeline_connected_mock - ): - pipeline_connected_mock.return_value = True - stage.pending_get_request = mocker.MagicMock() - - get_twin_op.twin = new_twin - get_twin_op.complete() - - assert stage.pending_get_request is None - - @pytest.mark.it("Sends a new GetTwinOperation if the op completes with an error") - def test_sends_new_get_twin_operation_on_failure( - self, stage, get_twin_op, arbitrary_exception, pipeline_connected_mock - ): - pipeline_connected_mock.return_value = True - assert stage.send_op_down.call_count == 0 - get_twin_op.complete(error=arbitrary_exception) - - assert stage.send_op_down.call_count == 1 - assert isinstance(stage.send_op_down.call_args[0][0], pipeline_ops_iothub.GetTwinOperation) - - @pytest.mark.it( - "Sets `pending_get_request` to the new GetTwinOperation if the op completes with an error" - ) - def test_sets_pending_request_to_none_on_failure( - self, mocker, stage, get_twin_op, arbitrary_exception, pipeline_connected_mock - ): - pipeline_connected_mock.return_value = True - old_get_request = mocker.MagicMock() - stage.pending_get_request = old_get_request - - get_twin_op.complete(error=arbitrary_exception) - - assert stage.pending_get_request is not old_get_request - assert isinstance(stage.pending_get_request, pipeline_ops_iothub.GetTwinOperation) - - @pytest.mark.it( - "Does not send a `TwinDesiredPropertiesPatchEvent` if the op completes with an error" - ) - def test_doesnt_send_patch_event_if_error( - self, stage, get_twin_op, arbitrary_exception, pipeline_connected_mock - ): - pipeline_connected_mock.return_value = True - get_twin_op.complete(arbitrary_exception) - - assert stage.send_event_up.call_count == 0 - - @pytest.mark.it( - "Sends a `TwinDesiredPropertiesPatchEvent` if the desired properties '$version' doesn't match the `last_version_seen`" - ) - def test_sends_patch_event_if_different_version( - self, mocker, stage, 
get_twin_op, new_twin, new_version, pipeline_connected_mock - ): - pipeline_connected_mock.return_value = True - stage.last_version_seen = mocker.MagicMock() - - get_twin_op.twin = new_twin - get_twin_op.complete() - - assert stage.send_event_up.call_count == 1 - assert isinstance( - stage.send_event_up.call_args[0][0], - pipeline_events_iothub.TwinDesiredPropertiesPatchEvent, - ) - - @pytest.mark.it( - "Does not send a `TwinDesiredPropertiesPatchEvent` if the desired properties '$version' matches the `last_version_seen`" - ) - def test_doesnt_send_patch_event_if_same_version( - self, stage, get_twin_op, new_twin, new_version, pipeline_connected_mock - ): - pipeline_connected_mock.return_value = True - stage.last_version_seen = new_version - - get_twin_op.twin = new_twin - get_twin_op.complete() - - assert stage.send_event_up.call_count == 0 - - @pytest.mark.it( - "Does not change the `last_version_seen` attribute if the op completes with an error" - ) - def test_doesnt_change_last_version_seen_if_error( - self, mocker, stage, get_twin_op, arbitrary_exception, pipeline_connected_mock - ): - pipeline_connected_mock.return_value = True - old_version = mocker.MagicMock() - stage.last_version_seen = old_version - - get_twin_op.complete(error=arbitrary_exception) - - assert stage.last_version_seen == old_version - - @pytest.mark.it( - "Sets the `last_version_seen` attribute to the new version if the desired properties '$version' doesn't match the `last_version_seen`" - ) - def test_changes_last_version_seen_if_different_version( - self, mocker, stage, get_twin_op, new_twin, new_version, pipeline_connected_mock - ): - pipeline_connected_mock.return_value = True - stage.last_version_seen = mocker.MagicMock() - - get_twin_op.twin = new_twin - get_twin_op.complete() - - assert stage.last_version_seen == new_version - - @pytest.mark.it( - "Does not change the `last_version_seen` attribute if the desired properties '$version' matches the `last_version_seen`" - ) - def 
test_does_not_change_last_version_seen_if_same_version( - self, stage, get_twin_op, new_twin, new_version, pipeline_connected_mock - ): - pipeline_connected_mock.return_value = True - stage.last_version_seen = new_version - - get_twin_op.twin = new_twin - get_twin_op.complete() - - assert stage.last_version_seen == new_version - - -@pytest.mark.describe( - "EnsureDesiredPropertiesStage - OCCURRENCE: GetTwinOperation that was sent down by this stage completes and pipeline is NOT connected" -) -class TestEnsureDesiredPropertiesStageWhenGetTwinOperationCompletesNotConnected( - EnsureDesiredPropertiesStageTestConfig -): - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.nucleus.pipeline_configuration.ensure_desired_properties = True - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - @pytest.fixture - def get_twin_op(self, stage): - stage.last_version_seen = -1 - assert stage.send_op_down.call_count == 0 - stage.handle_pipeline_event(pipeline_events_base.ConnectedEvent()) - - get_twin_op = stage.send_op_down.call_args[0][0] - assert isinstance(get_twin_op, pipeline_ops_iothub.GetTwinOperation) - assert stage.send_op_down.call_count == 1 - - stage.send_op_down.reset_mock() - stage.send_event_up.reset_mock() - - return get_twin_op - - @pytest.fixture - def new_version(self): - return 1234 - - @pytest.fixture - def new_twin(self, new_version): - return {"desired": {"$version": new_version}, "reported": {}} - - @pytest.mark.it("Does not send a new GetTwinOperation if the op completes with success") - def test_does_not_send_new_get_twin_operation_on_success( - self, stage, get_twin_op, new_twin, pipeline_connected_mock - ): - pipeline_connected_mock.return_value = False - get_twin_op.twin = new_twin - get_twin_op.complete() - - assert stage.send_op_down.call_count == 0 - - 
@pytest.mark.it("Sets `pending_get_request` to None if the op completes with success") - def test_sets_pending_request_to_none_on_success( - self, mocker, stage, get_twin_op, new_twin, pipeline_connected_mock - ): - pipeline_connected_mock.return_value = False - stage.pending_get_request = mocker.MagicMock() - - get_twin_op.twin = new_twin - get_twin_op.complete() - - assert stage.pending_get_request is None - - @pytest.mark.it("Does not send a new GetTwinOperation if the op completes with an error") - def test_doesnt_send_new_get_twin_operation_on_failure( - self, stage, get_twin_op, arbitrary_exception, pipeline_connected_mock - ): - pipeline_connected_mock.return_value = False - assert stage.send_op_down.call_count == 0 - get_twin_op.complete(error=arbitrary_exception) - assert stage.send_op_down.call_count == 0 - - @pytest.mark.it( - "Does not send a `TwinDesiredPropertiesPatchEvent` if the op completes with an error" - ) - def test_doesnt_send_patch_event_if_error( - self, stage, get_twin_op, arbitrary_exception, pipeline_connected_mock - ): - pipeline_connected_mock.return_value = False - get_twin_op.complete(arbitrary_exception) - - assert stage.send_event_up.call_count == 0 - - -############################### -# TWIN REQUEST RESPONSE STAGE # -############################### - - -class TwinRequestResponseStageTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_iothub.TwinRequestResponseStage - - @pytest.fixture - def init_kwargs(self): - return {} - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_iothub.TwinRequestResponseStage, - stage_test_config_class=TwinRequestResponseStageTestConfig, -) - - 
-@pytest.mark.describe("TwinRequestResponseStage - .run_op() -- Called with GetTwinOperation") -class TestTwinRequestResponseStageRunOpWithGetTwinOperation( - StageRunOpTestBase, TwinRequestResponseStageTestConfig -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_iothub.GetTwinOperation(callback=mocker.MagicMock()) - - @pytest.mark.it( - "Sends a new RequestAndResponseOperation down the pipeline, configured to request a twin" - ) - def test_request_and_response_op(self, mocker, stage, op): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_base.RequestAndResponseOperation) - assert new_op.request_type == "twin" - assert new_op.method == "GET" - assert new_op.resource_location == "/" - assert new_op.request_body == " " - - -@pytest.mark.describe( - "TwinRequestResponseStage - .run_op() -- Called with PatchTwinReportedPropertiesOperation" -) -class TestTwinRequestResponseStageRunOpWithPatchTwinReportedPropertiesOperation( - StageRunOpTestBase, TwinRequestResponseStageTestConfig -): - @pytest.fixture(params=["Dictionary Patch", "String Patch", "Integer Patch", "None Patch"]) - def json_patch(self, request): - if request.param == "Dictionary Patch": - return {"json_key": "json_val"} - elif request.param == "String Patch": - return "some_json" - elif request.param == "Integer Patch": - return 1234 - elif request.param == "None Patch": - return None - - @pytest.fixture - def op(self, mocker, json_patch): - return pipeline_ops_iothub.PatchTwinReportedPropertiesOperation( - patch=json_patch, callback=mocker.MagicMock() - ) - - @pytest.mark.it( - "Sends a new RequestAndResponseOperation down the pipeline, configured to send a twin reported properties patch, with the patch serialized as a JSON string" - ) - def test_request_and_response_op(self, mocker, stage, op): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - new_op = 
stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_base.RequestAndResponseOperation) - assert new_op.request_type == "twin" - assert new_op.method == "PATCH" - assert new_op.resource_location == "/properties/reported/" - assert new_op.request_body == json.dumps(op.patch) - - -@pytest.mark.describe( - "TwinRequestResponseStage - .run_op() -- Called with other arbitrary operation" -) -class TestTwinRequestResponseStageRunOpWithArbitraryOperation( - StageRunOpTestBase, TwinRequestResponseStageTestConfig -): - @pytest.fixture - def op(self, arbitrary_op): - return arbitrary_op - - @pytest.mark.it("Sends the operation down the pipeline") - def test_sends_op_down(self, mocker, stage, op): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - -# TODO: Provide a more accurate set of status codes for tests -@pytest.mark.describe( - "TwinRequestResponseStage - OCCURRENCE: RequestAndResponseOperation created from GetTwinOperation is completed" -) -class TestTwinRequestResponseStageWhenRequestAndResponseCreatedFromGetTwinOperationCompleted( - TwinRequestResponseStageTestConfig -): - @pytest.fixture - def get_twin_op(self, mocker): - return pipeline_ops_iothub.GetTwinOperation(callback=mocker.MagicMock()) - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, get_twin_op): - stage = cls_type(**init_kwargs) - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - - # Run the GetTwinOperation - stage.run_op(get_twin_op) - - return stage - - @pytest.fixture - def request_and_response_op(self, stage): - assert stage.send_op_down.call_count == 1 - op = stage.send_op_down.call_args[0][0] - assert isinstance(op, pipeline_ops_base.RequestAndResponseOperation) - - # reset the stage mock for convenience - stage.send_op_down.reset_mock() - - return op - - @pytest.mark.it( - "Completes the 
GetTwinOperation unsuccessfully, with the error from the RequestAndResponseOperation, if the RequestAndResponseOperation is completed unsuccessfully" - ) - @pytest.mark.parametrize( - "has_response_body", [True, False], ids=["With Response Body", "No Response Body"] - ) - @pytest.mark.parametrize( - "status_code", - [ - pytest.param(None, id="Status Code: None"), - pytest.param(200, id="Status Code: 200"), - pytest.param(300, id="Status Code: 300"), - pytest.param(400, id="Status Code: 400"), - pytest.param(500, id="Status Code: 500"), - ], - ) - def test_request_and_response_op_completed_with_err( - self, - stage, - get_twin_op, - request_and_response_op, - arbitrary_exception, - status_code, - has_response_body, - ): - assert not get_twin_op.completed - assert not request_and_response_op.completed - - # NOTE: It shouldn't happen that an operation completed with error has a status code or a - # response body, but it IS possible. - request_and_response_op.status_code = status_code - if has_response_body: - request_and_response_op.response_body = b'{"key": "value"}' - request_and_response_op.complete(error=arbitrary_exception) - - assert request_and_response_op.completed - assert request_and_response_op.error is arbitrary_exception - assert get_twin_op.completed - assert get_twin_op.error is arbitrary_exception - # Twin is NOT returned - assert get_twin_op.twin is None - - @pytest.mark.it( - "Completes the GetTwinOperation unsuccessfully with a ServiceError if the RequestAndResponseOperation is completed successfully with a status code indicating an unsuccessful result from the service" - ) - @pytest.mark.parametrize( - "has_response_body", [True, False], ids=["With Response Body", "No Response Body"] - ) - @pytest.mark.parametrize( - "status_code", - [ - pytest.param(300, id="Status Code: 300"), - pytest.param(400, id="Status Code: 400"), - pytest.param(500, id="Status Code: 500"), - ], - ) - def test_request_and_response_op_completed_success_with_bad_code( - self, 
stage, get_twin_op, request_and_response_op, status_code, has_response_body - ): - assert not get_twin_op.completed - assert not request_and_response_op.completed - - request_and_response_op.status_code = status_code - if has_response_body: - request_and_response_op.response_body = b'{"key": "value"}' - request_and_response_op.complete() - - assert request_and_response_op.completed - assert request_and_response_op.error is None - assert get_twin_op.completed - assert isinstance(get_twin_op.error, ServiceError) - # Twin is NOT returned - assert get_twin_op.twin is None - - @pytest.mark.it( - "Completes the GetTwinOperation successfully (with the JSON deserialized response body from the RequestAndResponseOperation as the twin) if the RequestAndResponseOperation is completed successfully with a status code indicating a successful result from the service" - ) - @pytest.mark.parametrize( - "response_body, expected_twin", - [ - pytest.param(b'{"key": "value"}', {"key": "value"}, id="Twin 1"), - pytest.param(b'{"key1": {"key2": "value"}}', {"key1": {"key2": "value"}}, id="Twin 2"), - pytest.param( - b'{"key1": {"key2": {"key3": "value1", "key4": "value2"}, "key5": "value3"}, "key6": {"key7": "value4"}, "key8": "value5"}', - { - "key1": {"key2": {"key3": "value1", "key4": "value2"}, "key5": "value3"}, - "key6": {"key7": "value4"}, - "key8": "value5", - }, - id="Twin 3", - ), - ], - ) - def test_request_and_response_op_completed_success_with_good_code( - self, stage, get_twin_op, request_and_response_op, response_body, expected_twin - ): - assert not get_twin_op.completed - assert not request_and_response_op.completed - - request_and_response_op.status_code = 200 - request_and_response_op.response_body = response_body - request_and_response_op.complete() - - assert request_and_response_op.completed - assert request_and_response_op.error is None - assert get_twin_op.completed - assert get_twin_op.error is None - assert get_twin_op.twin == expected_twin - - 
-@pytest.mark.describe( - "TwinRequestResponseStage - OCCURRENCE: RequestAndResponseOperation created from PatchTwinReportedPropertiesOperation is completed" -) -class TestTwinRequestResponseStageWhenRequestAndResponseCreatedFromPatchTwinReportedPropertiesOperation( - TwinRequestResponseStageTestConfig -): - @pytest.fixture - def patch_twin_reported_properties_op(self, mocker): - return pipeline_ops_iothub.PatchTwinReportedPropertiesOperation( - patch={"json_key": "json_val"}, callback=mocker.MagicMock() - ) - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, patch_twin_reported_properties_op): - stage = cls_type(**init_kwargs) - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - - # Run the GetTwinOperation - stage.run_op(patch_twin_reported_properties_op) - - return stage - - @pytest.fixture - def request_and_response_op(self, stage): - assert stage.send_op_down.call_count == 1 - op = stage.send_op_down.call_args[0][0] - assert isinstance(op, pipeline_ops_base.RequestAndResponseOperation) - - # reset the stage mock for convenience - stage.send_op_down.reset_mock() - - return op - - @pytest.mark.it( - "Completes the PatchTwinReportedPropertiesOperation unsuccessfully, with the error from the RequestAndResponseOperation, if the RequestAndResponseOperation is completed unsuccessfully" - ) - @pytest.mark.parametrize( - "status_code", - [ - pytest.param(None, id="Status Code: None"), - pytest.param(200, id="Status Code: 200"), - pytest.param(300, id="Status Code: 300"), - pytest.param(400, id="Status Code: 400"), - pytest.param(500, id="Status Code: 500"), - ], - ) - def test_request_and_response_op_completed_with_err( - self, - stage, - patch_twin_reported_properties_op, - request_and_response_op, - arbitrary_exception, - status_code, - ): - assert not patch_twin_reported_properties_op.completed - assert not request_and_response_op.completed - - # NOTE: It shouldn't 
happen that an operation completed with error has a status code - # but it IS possible - request_and_response_op.status_code = status_code - request_and_response_op.complete(error=arbitrary_exception) - - assert request_and_response_op.completed - assert request_and_response_op.error is arbitrary_exception - assert patch_twin_reported_properties_op.completed - assert patch_twin_reported_properties_op.error is arbitrary_exception - - @pytest.mark.it( - "Completes the PatchTwinReportedPropertiesOperation unsuccessfully with a ServiceError if the RequestAndResponseOperation is completed successfully with a status code indicating an unsuccessful result from the service" - ) - @pytest.mark.parametrize( - "status_code", - [ - pytest.param(300, id="Status Code: 300"), - pytest.param(400, id="Status Code: 400"), - pytest.param(500, id="Status Code: 500"), - ], - ) - def test_request_and_response_op_completed_success_with_bad_code( - self, stage, patch_twin_reported_properties_op, request_and_response_op, status_code - ): - assert not patch_twin_reported_properties_op.completed - assert not request_and_response_op.completed - - request_and_response_op.status_code = status_code - request_and_response_op.complete() - - assert request_and_response_op.completed - assert request_and_response_op.error is None - assert patch_twin_reported_properties_op.completed - assert isinstance(patch_twin_reported_properties_op.error, ServiceError) - - @pytest.mark.it( - "Completes the PatchTwinReportedPropertiesOperation successfully if the RequestAndResponseOperation is completed successfully with a status code indicating a successful result from the service" - ) - def test_request_and_response_op_completed_success_with_good_code( - self, stage, patch_twin_reported_properties_op, request_and_response_op - ): - assert not patch_twin_reported_properties_op.completed - assert not request_and_response_op.completed - - request_and_response_op.status_code = 200 - request_and_response_op.complete() 
- - assert request_and_response_op.completed - assert request_and_response_op.error is None - assert patch_twin_reported_properties_op.completed - assert patch_twin_reported_properties_op.error is None diff --git a/tests/unit/iothub/pipeline/test_pipeline_stages_iothub_http.py b/tests/unit/iothub/pipeline/test_pipeline_stages_iothub_http.py deleted file mode 100644 index a5a6ba08b..000000000 --- a/tests/unit/iothub/pipeline/test_pipeline_stages_iothub_http.py +++ /dev/null @@ -1,756 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import logging -import pytest -import json -import sys -import urllib -from azure.iot.device.common.pipeline import pipeline_ops_http -from azure.iot.device.iothub.pipeline import ( - pipeline_ops_iothub_http, - pipeline_stages_iothub_http, - config, -) -from azure.iot.device.exceptions import ServiceError -from tests.unit.common.pipeline.helpers import StageRunOpTestBase -from tests.unit.common.pipeline import pipeline_stage_test -from azure.iot.device import constant as pkg_constant -from azure.iot.device import user_agent - -logging.basicConfig(level=logging.DEBUG) -pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") -this_module = sys.modules[__name__] - -################### -# COMMON FIXTURES # -################### - - -@pytest.fixture(params=[True, False], ids=["With error", "No error"]) -def op_error(request, arbitrary_exception): - if request.param: - return arbitrary_exception - else: - return None - - -@pytest.fixture -def mock_http_path_iothub(mocker): - mock = mocker.patch( - "azure.iot.device.iothub.pipeline.pipeline_stages_iothub_http.http_path_iothub" - ) - return mock - - -################################## -# IOT HUB HTTP TRANSLATION STAGE # 
-################################## - - -class IoTHubHTTPTranslationStageTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_iothub_http.IoTHubHTTPTranslationStage - - @pytest.fixture - def init_kwargs(self): - return {} - - @pytest.fixture - def pipeline_config(self, mocker): - # auth type shouldn't matter for this stage, so just give it a fake sastoken for now. - # Manually override to make this for modules - cfg = config.IoTHubPipelineConfig( - hostname="http://my.hostname", device_id="my_device", sastoken=mocker.MagicMock() - ) - return cfg - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus, pipeline_config): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.nucleus.pipeline_configuration = pipeline_config - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_iothub_http.IoTHubHTTPTranslationStage, - stage_test_config_class=IoTHubHTTPTranslationStageTestConfig, -) - - -@pytest.mark.describe( - "IoTHubHTTPTranslationStage - .run_op() -- Called with MethodInvokeOperation op" -) -class TestIoTHubHTTPTranslationStageRunOpCalledWithMethodInvokeOperation( - IoTHubHTTPTranslationStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def pipeline_config(self, mocker): - # Because Method related functionality is limited to Module, configure the stage for a module - # auth type shouldn't matter for this stage, so just give it a fake sastoken for now. 
- cfg = config.IoTHubPipelineConfig( - hostname="http://my.hostname", - gateway_hostname="http://my.gateway.hostname", - device_id="my_device", - module_id="my_module", - sastoken=mocker.MagicMock(), - ) - return cfg - - @pytest.fixture(params=["Targeting Device Method", "Targeting Module Method"]) - def op(self, mocker, request): - method_params = {"arg1": "val", "arg2": 2, "arg3": True} - if request.param == "Targeting Device Method": - return pipeline_ops_iothub_http.MethodInvokeOperation( - target_device_id="fake_target_device_id", - target_module_id=None, - method_params=method_params, - callback=mocker.MagicMock(), - ) - else: - return pipeline_ops_iothub_http.MethodInvokeOperation( - target_device_id="fake_target_device_id", - target_module_id="fake_target_module_id", - method_params=method_params, - callback=mocker.MagicMock(), - ) - - @pytest.mark.it("Sends a new HTTPRequestAndResponseOperation op down the pipeline") - def test_sends_op_down(self, mocker, stage, op): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - @pytest.mark.it( - "Configures the HTTPRequestAndResponseOperation with request details for sending a Method Invoke request" - ) - def test_sends_get_storage_request(self, mocker, stage, op, mock_http_path_iothub): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Validate request - assert mock_http_path_iothub.get_method_invoke_path.call_count == 1 - assert mock_http_path_iothub.get_method_invoke_path.call_args == mocker.call( - op.target_device_id, op.target_module_id - ) - expected_path = mock_http_path_iothub.get_method_invoke_path.return_value - - assert new_op.method == "POST" - assert new_op.path == 
expected_path - assert new_op.query_params == "api-version={}".format(pkg_constant.IOTHUB_API_VERSION) - - @pytest.mark.it( - "Configures the HTTPRequestAndResponseOperation with the headers for a Method Invoke request" - ) - @pytest.mark.parametrize( - "custom_user_agent", - [ - pytest.param("", id="No custom user agent"), - pytest.param("MyCustomUserAgent", id="With custom user agent"), - pytest.param( - "My/Custom?User+Agent", id="With custom user agent containing reserved characters" - ), - pytest.param(12345, id="Non-string custom user agent"), - ], - ) - def test_new_op_headers(self, mocker, stage, op, custom_user_agent, pipeline_config): - stage.nucleus.pipeline_configuration.product_info = custom_user_agent - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Validate headers - expected_user_agent = urllib.parse.quote_plus( - user_agent.get_iothub_user_agent() + str(custom_user_agent) - ) - expected_edge_string = "{}/{}".format(pipeline_config.device_id, pipeline_config.module_id) - - assert new_op.headers["Host"] == pipeline_config.gateway_hostname - assert new_op.headers["Content-Type"] == "application/json" - assert new_op.headers["Content-Length"] == str(len(new_op.body)) - assert new_op.headers["x-ms-edge-moduleId"] == expected_edge_string - assert new_op.headers["User-Agent"] == expected_user_agent - - @pytest.mark.it( - "Configures the HTTPRequestAndResponseOperation with a body for a Method Invoke request" - ) - def test_new_op_body(self, mocker, stage, op): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Validate body - assert new_op.body == json.dumps(op.method_params) - - @pytest.mark.it( - "Completes the original 
MethodInvokeOperation op (no error) if the new HTTPRequestAndResponseOperation op is completed later on (no error) with a status code indicating success" - ) - def test_new_op_completes_with_good_code(self, mocker, stage, op): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Neither op is completed - assert not op.completed - assert op.error is None - assert not new_op.completed - assert new_op.error is None - - # Complete new op - new_op.response_body = b'{"some_response_key": "some_response_value"}' - new_op.status_code = 200 - new_op.complete() - - # Both ops are now completed successfully - assert new_op.completed - assert new_op.error is None - assert op.completed - assert op.error is None - - @pytest.mark.it( - "Deserializes the completed HTTPRequestAndResponseOperation op's 'response_body' (the received storage info) and set it on the MethodInvokeOperation op as the 'method_response', if the HTTPRequestAndResponseOperation is completed later (no error) with a status code indicating success" - ) - @pytest.mark.parametrize( - "response_body, expected_method_response", - [ - pytest.param( - b'{"key": "val"}', {"key": "val"}, id="Response Body: dict value as bytestring" - ), - pytest.param( - b'{"key": "val", "key2": {"key3": "val2"}}', - {"key": "val", "key2": {"key3": "val2"}}, - id="Response Body: dict value as bytestring", - ), - ], - ) - def test_deserializes_response( - self, mocker, stage, op, response_body, expected_method_response - ): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Original op has no 'method_response' - assert op.method_response is None - - # Complete new op - new_op.response_body = response_body - 
new_op.status_code = 200 - new_op.complete() - - # Method Response is set - assert op.method_response == expected_method_response - - @pytest.mark.it( - "Completes the original MethodInvokeOperation op with a ServiceError if the new HTTPRequestAndResponseOperation is completed later on (no error) with a status code indicating non-success" - ) - @pytest.mark.parametrize( - "status_code", - [ - pytest.param(300, id="Status Code: 300"), - pytest.param(400, id="Status Code: 400"), - pytest.param(500, id="Status Code: 500"), - ], - ) - def test_new_op_completes_with_bad_code(self, mocker, stage, op, status_code): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Neither op is completed - assert not op.completed - assert op.error is None - assert not new_op.completed - assert new_op.error is None - - # Complete new op successfully (but with a bad status code) - new_op.status_code = status_code - new_op.complete() - - # The original op is now completed with a ServiceError - assert new_op.completed - assert new_op.error is None - assert op.completed - assert isinstance(op.error, ServiceError) - - @pytest.mark.it( - "Completes the original MethodInvokeOperation op with the error from the new HTTPRequestAndResponseOperation, if the HTTPRequestAndResponseOperation is completed later on with error" - ) - def test_new_op_completes_with_error(self, mocker, stage, op, arbitrary_exception): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Neither op is completed - assert not op.completed - assert op.error is None - assert not new_op.completed - assert new_op.error is None - - # Complete new op with error - new_op.complete(error=arbitrary_exception) - - 
# The original op is now completed with a ServiceError - assert new_op.completed - assert new_op.error is arbitrary_exception - assert op.completed - assert op.error is arbitrary_exception - - -@pytest.mark.describe( - "IoTHubHTTPTranslationStage - .run_op() -- Called with GetStorageInfoOperation op" -) -class TestIoTHubHTTPTranslationStageRunOpCalledWithGetStorageInfoOperation( - IoTHubHTTPTranslationStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def pipeline_config(self, mocker): - # Because Storage/Blob related functionality is limited to Device, configure pipeline for a device - # auth type shouldn't matter for this stage, so just give it a fake sastoken for now. - cfg = config.IoTHubPipelineConfig( - hostname="http://my.hostname", device_id="my_device", sastoken=mocker.MagicMock() - ) - return cfg - - @pytest.fixture - def op(self, mocker): - return pipeline_ops_iothub_http.GetStorageInfoOperation( - blob_name="fake_blob_name", callback=mocker.MagicMock() - ) - - @pytest.mark.it("Sends a new HTTPRequestAndResponseOperation op down the pipeline") - def test_sends_op_down(self, mocker, stage, op): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - @pytest.mark.it( - "Configures the HTTPRequestAndResponseOperation with request details for sending a Get Storage Info request" - ) - def test_sends_get_storage_request( - self, mocker, stage, op, mock_http_path_iothub, pipeline_config - ): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Validate request - assert mock_http_path_iothub.get_storage_info_for_blob_path.call_count == 1 - assert mock_http_path_iothub.get_storage_info_for_blob_path.call_args == mocker.call( - 
pipeline_config.device_id - ) - expected_path = mock_http_path_iothub.get_storage_info_for_blob_path.return_value - - assert new_op.method == "POST" - assert new_op.path == expected_path - assert new_op.query_params == "api-version={}".format(pkg_constant.IOTHUB_API_VERSION) - - @pytest.mark.it( - "Configures the HTTPRequestAndResponseOperation with the headers for a Get Storage Info request" - ) - @pytest.mark.parametrize( - "custom_user_agent", - [ - pytest.param("", id="No custom user agent"), - pytest.param("MyCustomUserAgent", id="With custom user agent"), - pytest.param( - "My/Custom?User+Agent", id="With custom user agent containing reserved characters" - ), - pytest.param(12345, id="Non-string custom user agent"), - ], - ) - def test_new_op_headers(self, mocker, stage, op, custom_user_agent, pipeline_config): - stage.nucleus.pipeline_configuration.product_info = custom_user_agent - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Validate headers - expected_user_agent = urllib.parse.quote_plus( - user_agent.get_iothub_user_agent() + str(custom_user_agent) - ) - - assert new_op.headers["Host"] == pipeline_config.hostname - assert new_op.headers["Accept"] == "application/json" - assert new_op.headers["Content-Type"] == "application/json" - assert new_op.headers["Content-Length"] == str(len(new_op.body)) - assert new_op.headers["User-Agent"] == expected_user_agent - - @pytest.mark.it( - "Configures the HTTPRequestAndResponseOperation with a body for a Get Storage Info request" - ) - def test_new_op_body(self, mocker, stage, op): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Validate body - assert new_op.body == '{{"blobName": 
"{}"}}'.format(op.blob_name) - - @pytest.mark.it( - "Completes the original GetStorageInfoOperation op (no error) if the new HTTPRequestAndResponseOperation is completed later on (no error) with a status code indicating success" - ) - def test_new_op_completes_with_good_code(self, mocker, stage, op): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Neither op is completed - assert not op.completed - assert op.error is None - assert not new_op.completed - assert new_op.error is None - - # Complete new op - new_op.response_body = b'{"json": "response"}' - new_op.status_code = 200 - new_op.complete() - - # Both ops are now completed successfully - assert new_op.completed - assert new_op.error is None - assert op.completed - assert op.error is None - - @pytest.mark.it( - "Deserializes the completed HTTPRequestAndResponseOperation op's 'response_body' (the received storage info) and set it on the GetStorageInfoOperation as the 'storage_info', if the HTTPRequestAndResponseOperation is completed later (no error) with a status code indicating success" - ) - def test_deserializes_response(self, mocker, stage, op): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Original op has no 'storage_info' - assert op.storage_info is None - - # Complete new op - new_op.response_body = b'{\ - "hostName": "fake_hostname",\ - "containerName": "fake_container_name",\ - "blobName": "fake_blob_name",\ - "sasToken": "fake_sas_token",\ - "correlationId": "fake_correlation_id"\ - }' - new_op.status_code = 200 - new_op.complete() - - # Storage Info is set - assert op.storage_info == { - "hostName": "fake_hostname", - "containerName": "fake_container_name", - 
"blobName": "fake_blob_name", - "sasToken": "fake_sas_token", - "correlationId": "fake_correlation_id", - } - - @pytest.mark.it( - "Completes the original GetStorageInfoOperation op with a ServiceError if the new HTTPRequestAndResponseOperation is completed later on (no error) with a status code indicating non-success" - ) - @pytest.mark.parametrize( - "status_code", - [ - pytest.param(300, id="Status Code: 300"), - pytest.param(400, id="Status Code: 400"), - pytest.param(500, id="Status Code: 500"), - ], - ) - def test_new_op_completes_with_bad_code(self, mocker, stage, op, status_code): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Neither op is completed - assert not op.completed - assert op.error is None - assert not new_op.completed - assert new_op.error is None - - # Complete new op successfully (but with a bad status code) - new_op.status_code = status_code - new_op.complete() - - # The original op is now completed with a ServiceError - assert new_op.completed - assert new_op.error is None - assert op.completed - assert isinstance(op.error, ServiceError) - - @pytest.mark.it( - "Completes the original GetStorageInfoOperation op with the error from the new HTTPRequestAndResponseOperation, if the HTTPRequestAndResponseOperation is completed later on with error" - ) - def test_new_op_completes_with_error(self, mocker, stage, op, arbitrary_exception): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Neither op is completed - assert not op.completed - assert op.error is None - assert not new_op.completed - assert new_op.error is None - - # Complete new op with error - new_op.complete(error=arbitrary_exception) - - # The original op 
is now completed with a ServiceError - assert new_op.completed - assert new_op.error is arbitrary_exception - assert op.completed - assert op.error is arbitrary_exception - - -@pytest.mark.describe( - "IoTHubHTTPTranslationStage - .run_op() -- Called with NotifyBlobUploadStatusOperation op" -) -class TestIoTHubHTTPTranslationStageRunOpCalledWithNotifyBlobUploadStatusOperation( - IoTHubHTTPTranslationStageTestConfig, StageRunOpTestBase -): - @pytest.fixture - def pipeline_config(self, mocker): - # Because Storage/Blob related functionality is limited to Device, configure pipeline for a device - # auth type shouldn't matter for this stage, so just give it a fake sastoken for now. - cfg = config.IoTHubPipelineConfig( - hostname="http://my.hostname", device_id="my_device", sastoken=mocker.MagicMock() - ) - return cfg - - @pytest.fixture - def op(self, mocker): - return pipeline_ops_iothub_http.NotifyBlobUploadStatusOperation( - correlation_id="fake_correlation_id", - is_success=True, - status_code=203, - status_description="fake_description", - callback=mocker.MagicMock(), - ) - - @pytest.mark.it("Sends a new HTTPRequestAndResponseOperation op down the pipeline") - def test_sends_op_down(self, mocker, stage, op): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - @pytest.mark.it( - "Configures the HTTPRequestAndResponseOperation with request details for sending a Notify Blob Upload Status request" - ) - def test_sends_get_storage_request( - self, mocker, stage, op, mock_http_path_iothub, pipeline_config - ): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Validate request - assert 
mock_http_path_iothub.get_notify_blob_upload_status_path.call_count == 1 - assert mock_http_path_iothub.get_notify_blob_upload_status_path.call_args == mocker.call( - pipeline_config.device_id - ) - expected_path = mock_http_path_iothub.get_notify_blob_upload_status_path.return_value - - assert new_op.method == "POST" - assert new_op.path == expected_path - assert new_op.query_params == "api-version={}".format(pkg_constant.IOTHUB_API_VERSION) - - @pytest.mark.it( - "Configures the HTTPRequestAndResponseOperation with the headers for a Notify Blob Upload Status request" - ) - @pytest.mark.parametrize( - "custom_user_agent", - [ - pytest.param("", id="No custom user agent"), - pytest.param("MyCustomUserAgent", id="With custom user agent"), - pytest.param( - "My/Custom?User+Agent", id="With custom user agent containing reserved characters" - ), - pytest.param(12345, id="Non-string custom user agent"), - ], - ) - def test_new_op_headers(self, mocker, stage, op, custom_user_agent, pipeline_config): - stage.nucleus.pipeline_configuration.product_info = custom_user_agent - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Validate headers - expected_user_agent = urllib.parse.quote_plus( - user_agent.get_iothub_user_agent() + str(custom_user_agent) - ) - - assert new_op.headers["Host"] == pipeline_config.hostname - assert new_op.headers["Content-Type"] == "application/json; charset=utf-8" - assert new_op.headers["Content-Length"] == str(len(new_op.body)) - assert new_op.headers["User-Agent"] == expected_user_agent - - @pytest.mark.it( - "Configures the HTTPRequestAndResponseOperation with a body for a Notify Blob Upload Status request" - ) - def test_new_op_body(self, mocker, stage, op): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = 
stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Validate body - header_dict = { - "correlationId": op.correlation_id, - "isSuccess": op.is_success, - "statusCode": op.request_status_code, - "statusDescription": op.status_description, - } - assert new_op.body == json.dumps(header_dict) - - @pytest.mark.it( - "Completes the original NotifyBlobUploadStatusOperation op (no error) if the new HTTPRequestAndResponseOperation is completed later on (no error) with a status code indicating success" - ) - def test_new_op_completes_with_good_code(self, mocker, stage, op): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Neither op is completed - assert not op.completed - assert op.error is None - assert not new_op.completed - assert new_op.error is None - - # Complete new op - new_op.status_code = 200 - new_op.complete() - - # Both ops are now completed successfully - assert new_op.completed - assert new_op.error is None - assert op.completed - assert op.error is None - - @pytest.mark.it( - "Completes the original NotifyBlobUploadStatusOperation op with a ServiceError if the new HTTPRequestAndResponseOperation is completed later on (no error) with a status code indicating non-success" - ) - @pytest.mark.parametrize( - "status_code", - [ - pytest.param(300, id="Status Code: 300"), - pytest.param(400, id="Status Code: 400"), - pytest.param(500, id="Status Code: 500"), - ], - ) - def test_new_op_completes_with_bad_code(self, mocker, stage, op, status_code): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Neither op is completed - assert not op.completed - assert op.error is None 
- assert not new_op.completed - assert new_op.error is None - - # Complete new op successfully (but with a bad status code) - new_op.status_code = status_code - new_op.complete() - - # The original op is now completed with a ServiceError - assert new_op.completed - assert new_op.error is None - assert op.completed - assert isinstance(op.error, ServiceError) - - @pytest.mark.it( - "Completes the original NotifyBlobUploadStatusOperation op with the error from the new HTTPRequestAndResponseOperation, if the HTTPRequestAndResponseOperation is completed later on with error" - ) - def test_new_op_completes_with_error(self, mocker, stage, op, arbitrary_exception): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) - - # Neither op is completed - assert not op.completed - assert op.error is None - assert not new_op.completed - assert new_op.error is None - - # Complete new op with error - new_op.complete(error=arbitrary_exception) - - # The original op is now completed with a ServiceError - assert new_op.completed - assert new_op.error is arbitrary_exception - assert op.completed - assert op.error is arbitrary_exception diff --git a/tests/unit/iothub/pipeline/test_pipeline_stages_iothub_mqtt.py b/tests/unit/iothub/pipeline/test_pipeline_stages_iothub_mqtt.py deleted file mode 100644 index b9e116f53..000000000 --- a/tests/unit/iothub/pipeline/test_pipeline_stages_iothub_mqtt.py +++ /dev/null @@ -1,1024 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import logging -import pytest -import json -import sys -import urllib -from azure.iot.device.common.pipeline import ( - pipeline_events_base, - pipeline_ops_base, - pipeline_ops_mqtt, - pipeline_events_mqtt, -) -from azure.iot.device.iothub.pipeline import ( - constant, - pipeline_events_iothub, - pipeline_ops_iothub, - pipeline_stages_iothub_mqtt, - config, - mqtt_topic_iothub, -) -from azure.iot.device.iothub.pipeline.exceptions import OperationError -from azure.iot.device.iothub.models.message import Message -from azure.iot.device.iothub.models.methods import MethodRequest, MethodResponse -from tests.unit.common.pipeline.helpers import StageRunOpTestBase, StageHandlePipelineEventTestBase -from tests.unit.common.pipeline import pipeline_stage_test -from azure.iot.device import constant as pkg_constant, user_agent - -logging.basicConfig(level=logging.DEBUG) -this_module = sys.modules[__name__] -pytestmark = pytest.mark.usefixtures("fake_pipeline_thread", "mock_mqtt_topic") - - -@pytest.fixture -def mock_mqtt_topic(mocker): - # Don't mock the whole module, just mock what we want to (which is most of it). - # Mocking out the get_x_topic style functions is useful, but the ones that - # match patterns and return bools (is_x_topic) making testing annoying if mocked. 
- mocker.patch.object(mqtt_topic_iothub, "get_telemetry_topic_for_publish") - mocker.patch.object(mqtt_topic_iothub, "get_method_topic_for_publish") - mocker.patch.object(mqtt_topic_iothub, "get_twin_topic_for_publish") - mocker.patch.object(mqtt_topic_iothub, "get_c2d_topic_for_subscribe") - mocker.patch.object(mqtt_topic_iothub, "get_input_topic_for_subscribe") - mocker.patch.object(mqtt_topic_iothub, "get_method_topic_for_subscribe") - mocker.patch.object(mqtt_topic_iothub, "get_twin_response_topic_for_subscribe") - mocker.patch.object(mqtt_topic_iothub, "get_twin_patch_topic_for_subscribe") - mocker.patch.object(mqtt_topic_iothub, "encode_message_properties_in_topic") - mocker.patch.object(mqtt_topic_iothub, "extract_message_properties_from_topic") - # It's kind of weird that we return the (unmocked) module, but it's easier this way, - # and since it's a module, not a function, we'd never treat it like a mock anyway - # (you don't check the call count of a module) - return mqtt_topic_iothub - - -@pytest.fixture(params=[True, False], ids=["With error", "No error"]) -def op_error(request, arbitrary_exception): - if request.param: - return arbitrary_exception - else: - return None - - -# NOTE: This fixture is defined out here rather than on a class because it is used for both -# EnableFeatureOperation and DisableFeatureOperation tests -@pytest.fixture -def expected_mqtt_topic_fn(mock_mqtt_topic, iothub_pipeline_feature): - if iothub_pipeline_feature == constant.C2D_MSG: - return mock_mqtt_topic.get_c2d_topic_for_subscribe - elif iothub_pipeline_feature == constant.INPUT_MSG: - return mock_mqtt_topic.get_input_topic_for_subscribe - elif iothub_pipeline_feature == constant.METHODS: - return mock_mqtt_topic.get_method_topic_for_subscribe - elif iothub_pipeline_feature == constant.TWIN: - return mock_mqtt_topic.get_twin_response_topic_for_subscribe - elif iothub_pipeline_feature == constant.TWIN_PATCHES: - return mock_mqtt_topic.get_twin_patch_topic_for_subscribe - 
else: - # This shouldn't happen - assert False - - -# NOTE: This fixture is defined out here rather than on a class because it is used for both -# EnableFeatureOperation and DisableFeatureOperation tests -@pytest.fixture -def expected_mqtt_topic_fn_call(mocker, iothub_pipeline_feature, stage): - if iothub_pipeline_feature == constant.C2D_MSG: - return mocker.call(stage.nucleus.pipeline_configuration.device_id) - elif iothub_pipeline_feature == constant.INPUT_MSG: - return mocker.call( - stage.nucleus.pipeline_configuration.device_id, - stage.nucleus.pipeline_configuration.module_id, - ) - else: - return mocker.call() - - -class IoTHubMQTTTranslationStageTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_iothub_mqtt.IoTHubMQTTTranslationStage - - @pytest.fixture - def init_kwargs(self): - return {} - - @pytest.fixture - def pipeline_config(self, mocker): - # NOTE 1: auth type shouldn't matter for this stage, so just give it a fake sastoken for now. - # NOTE 2: This config is configured for a device, not a module. 
Where relevant, override this - # fixture or dynamically add a module_id - cfg = config.IoTHubPipelineConfig( - hostname="http://my.hostname", device_id="my_device", sastoken=mocker.MagicMock() - ) - return cfg - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus, pipeline_config): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.nucleus.pipeline_configuration = pipeline_config - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_iothub_mqtt.IoTHubMQTTTranslationStage, - stage_test_config_class=IoTHubMQTTTranslationStageTestConfig, -) - - -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .run_op() -- Called with InitializePipelineOperation (Pipeline has Device Configuration)" -) -class TestIoTHubMQTTTranslationStageRunOpWithInitializePipelineOperationOnDevice( - StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) - - @pytest.mark.it("Derives the MQTT client id, and sets it on the op") - def test_client_id(self, stage, op, pipeline_config): - assert not hasattr(op, "client_id") - stage.run_op(op) - - assert op.client_id == pipeline_config.device_id - - @pytest.mark.it("Derives the MQTT username, and sets it on the op") - @pytest.mark.parametrize( - "cust_product_info", - [ - pytest.param("", id="No custom product info"), - pytest.param("my-product-info", id="With custom product info"), - pytest.param("my$product$info", id="With custom product info (URL encoding required)"), - ], - ) - def test_username(self, stage, op, pipeline_config, cust_product_info): - pipeline_config.product_info = cust_product_info - assert not hasattr(op, "username") - stage.run_op(op) - - 
expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}{custom_product_info}".format( - hostname=pipeline_config.hostname, - client_id=pipeline_config.device_id, - api_version=pkg_constant.IOTHUB_API_VERSION, - user_agent=urllib.parse.quote(user_agent.get_iothub_user_agent(), safe=""), - custom_product_info=urllib.parse.quote(pipeline_config.product_info, safe=""), - ) - assert op.username == expected_username - - @pytest.mark.it( - "Derives the MQTT username, and sets it on the op for digital twin specific scenarios" - ) - @pytest.mark.parametrize( - "digital_twin_product_info", - [ - pytest.param( - pkg_constant.DIGITAL_TWIN_PREFIX + ":com:example:ClimateSensor;1", - id="With custom product info", - ), - pytest.param( - pkg_constant.DIGITAL_TWIN_PREFIX + ":com:example:$Climate$Sensor;1", - id="With custom product info (URL encoding required)", - ), - ], - ) - def test_username_for_digital_twin(self, stage, op, pipeline_config, digital_twin_product_info): - pipeline_config.product_info = digital_twin_product_info - assert not hasattr(op, "username") - stage.run_op(op) - - expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}&{digital_twin_prefix}={custom_product_info}".format( - hostname=pipeline_config.hostname, - client_id=pipeline_config.device_id, - api_version=pkg_constant.DIGITAL_TWIN_API_VERSION, - user_agent=urllib.parse.quote(user_agent.get_iothub_user_agent(), safe=""), - digital_twin_prefix=pkg_constant.DIGITAL_TWIN_QUERY_HEADER, - custom_product_info=urllib.parse.quote(pipeline_config.product_info, safe=""), - ) - assert op.username == expected_username - - @pytest.mark.it( - "ALWAYS uses the pipeline configuration's hostname in the MQTT username and NEVER the gateway_hostname" - ) - def test_hostname_vs_gateway_hostname(self, stage, op, pipeline_config): - # NOTE: this is a sanity check test. 
There's no reason it should ever be using - # gateway hostname rather than hostname, but these are easily confused fields, so - # this test has been included to catch any possible errors down the road - pipeline_config.hostname = "http://my.hostname" - pipeline_config.gateway_hostname = "http://my.gateway.hostname" - stage.run_op(op) - - assert pipeline_config.hostname in op.username - assert pipeline_config.gateway_hostname not in op.username - - @pytest.mark.it("Sends the op down the pipeline") - def test_sends_down(self, mocker, stage, op): - stage.run_op(op) - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .run_op() -- Called with InitializePipelineOperation (Pipeline has Module Configuration)" -) -class TestIoTHubMQTTTranslationStageRunOpWithInitializePipelineOperationOnModule( - StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig -): - @pytest.fixture - def pipeline_config(self, mocker): - # auth type shouldn't matter for this stage, so just give it a fake sastoken for now. 
- cfg = config.IoTHubPipelineConfig( - hostname="http://my.hostname", - device_id="my_device", - module_id="my_module", - sastoken=mocker.MagicMock(), - ) - return cfg - - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) - - @pytest.mark.it("Derives the MQTT client id, and sets it on the op") - def test_client_id(self, stage, op, pipeline_config): - stage.run_op(op) - - expected_client_id = "{device_id}/{module_id}".format( - device_id=pipeline_config.device_id, module_id=pipeline_config.module_id - ) - assert op.client_id == expected_client_id - - @pytest.mark.it("Derives the MQTT username, and sets it on the op") - @pytest.mark.parametrize( - "cust_product_info", - [ - pytest.param("", id="No custom product info"), - pytest.param("my-product-info", id="With custom product info"), - pytest.param("my$product$info", id="With custom product info (URL encoding required)"), - ], - ) - def test_username(self, stage, op, pipeline_config, cust_product_info): - pipeline_config.product_info = cust_product_info - stage.run_op(op) - - expected_client_id = "{device_id}/{module_id}".format( - device_id=pipeline_config.device_id, module_id=pipeline_config.module_id - ) - expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}{custom_product_info}".format( - hostname=pipeline_config.hostname, - client_id=expected_client_id, - api_version=pkg_constant.IOTHUB_API_VERSION, - user_agent=urllib.parse.quote(user_agent.get_iothub_user_agent(), safe=""), - custom_product_info=urllib.parse.quote(pipeline_config.product_info, safe=""), - ) - assert op.username == expected_username - - @pytest.mark.it( - "Derives the MQTT username, and sets it on the op for digital twin specific scenarios" - ) - @pytest.mark.parametrize( - "digital_twin_product_info", - [ - pytest.param( - pkg_constant.DIGITAL_TWIN_PREFIX + ":com:example:ClimateSensor;1", - id="With custom product info", - 
), - pytest.param( - pkg_constant.DIGITAL_TWIN_PREFIX + ":com:example:$Climate$Sensor;1", - id="With custom product info (URL encoding required)", - ), - ], - ) - def test_username_for_digital_twin(self, stage, op, pipeline_config, digital_twin_product_info): - pipeline_config.product_info = digital_twin_product_info - assert not hasattr(op, "username") - stage.run_op(op) - - expected_client_id = "{device_id}/{module_id}".format( - device_id=pipeline_config.device_id, module_id=pipeline_config.module_id - ) - expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}&{digital_twin_prefix}={custom_product_info}".format( - hostname=pipeline_config.hostname, - client_id=expected_client_id, - api_version=pkg_constant.DIGITAL_TWIN_API_VERSION, - user_agent=urllib.parse.quote(user_agent.get_iothub_user_agent(), safe=""), - digital_twin_prefix=pkg_constant.DIGITAL_TWIN_QUERY_HEADER, - custom_product_info=urllib.parse.quote(pipeline_config.product_info, safe=""), - ) - assert op.username == expected_username - - @pytest.mark.it( - "ALWAYS uses the pipeline configuration's hostname in the MQTT username and NEVER the gateway_hostname" - ) - def test_hostname_vs_gateway_hostname(self, stage, op, pipeline_config): - # NOTE: this is a sanity check test. 
There's no reason it should ever be using - # gateway hostname rather than hostname, but these are easily confused fields, so - # this test has been included to catch any possible errors down the road - pipeline_config.hostname = "http://my.hostname" - pipeline_config.gateway_hostname = "http://my.gateway.hostname" - stage.run_op(op) - - assert pipeline_config.hostname in op.username - assert pipeline_config.gateway_hostname not in op.username - - @pytest.mark.it("Sends the op down the pipeline") - def test_sends_down(self, mocker, stage, op): - stage.run_op(op) - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - -# NOTE: All of the following run op tests are tested against a pipeline_config that has been -# configured for a Device Client, not a Module Client. It's worth considering parametrizing -# that fixture so that these tests all run twice - once for a Device, and once for a Module. -# HOWEVER, it's not strictly necessary, due to knowledge of implementation - we are testing that -# the expected values (including module id, which just happens to be set to None when configured -# for a device) are passed where they are expected to be passed. If they're being passed -# correctly, we know it would work no matter what the values are set to. -# -# This also avoids us having module specific tests for device-only features, and vice versa. -# -# In conclusion, while the pipeline_config fixture is technically configured for a device, -# all of the .run_op() tests are written as if it's completely generic. Perhaps this will -# need to change later on. 
- - -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .run_op() -- Called with SendD2CMessageOperation" -) -class TestIoTHubMQTTTranslationStageRunOpWithSendD2CMessageOperation( - StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_iothub.SendD2CMessageOperation( - message=Message("my message"), callback=mocker.MagicMock() - ) - - @pytest.mark.it( - "Derives the IoTHub telemetry topic from the device/module details, and encodes the op's message's properties in the resulting topic string" - ) - def test_telemetry_topic(self, mocker, stage, op, pipeline_config, mock_mqtt_topic): - # Although this requirement refers to message properties, we don't actually have to - # parametrize the op to have them, because the entire logic of encoding message properties - # is handled by the mocked out mqtt_topic_iothub library, so whether or not our fixture - # has message properties on the message or not is irrelevant. - stage.run_op(op) - - assert mock_mqtt_topic.get_telemetry_topic_for_publish.call_count == 1 - assert mock_mqtt_topic.get_telemetry_topic_for_publish.call_args == mocker.call( - device_id=pipeline_config.device_id, module_id=pipeline_config.module_id - ) - assert mock_mqtt_topic.encode_message_properties_in_topic.call_count == 1 - assert mock_mqtt_topic.encode_message_properties_in_topic.call_args == mocker.call( - op.message, mock_mqtt_topic.get_telemetry_topic_for_publish.return_value - ) - - @pytest.mark.it( - "Sends a new MQTTPublishOperation down the pipeline with the message data from the original op and the derived topic string" - ) - def test_sends_mqtt_publish_op_down(self, mocker, stage, op, mock_mqtt_topic): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_mqtt.MQTTPublishOperation) - assert new_op.topic == mock_mqtt_topic.encode_message_properties_in_topic.return_value - 
assert new_op.payload == op.message.data - - @pytest.mark.it("Completes the original op upon completion of the new MQTTPublishOperation") - def test_complete_resulting_op(self, stage, op, op_error): - stage.run_op(op) - assert not op.completed - - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - - new_op.complete(error=op_error) - - assert new_op.completed - assert new_op.error is op_error - assert op.completed - assert op.error is op_error - - -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .run_op() -- Called with SendOutputMessageOperation" -) -class TestIoTHubMQTTTranslationStageRunOpWithSendOutputMessageOperation( - StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_iothub.SendOutputMessageOperation( - message=Message("my message"), callback=mocker.MagicMock() - ) - - @pytest.mark.it( - "Derives the IoTHub telemetry topic using the device/module details, and encodes the op's message's properties in the resulting topic string" - ) - def test_telemetry_topic(self, mocker, stage, op, pipeline_config, mock_mqtt_topic): - # Although this requirement refers to message properties, we don't actually have to - # parametrize the op to have them, because the entire logic of encoding message properties - # is handled by the mocked out mqtt_topic_iothub library, so whether or not our fixture - # has message properties on the message or not is irrelevant. 
- stage.run_op(op) - - assert mock_mqtt_topic.get_telemetry_topic_for_publish.call_count == 1 - assert mock_mqtt_topic.get_telemetry_topic_for_publish.call_args == mocker.call( - device_id=pipeline_config.device_id, module_id=pipeline_config.module_id - ) - assert mock_mqtt_topic.encode_message_properties_in_topic.call_count == 1 - assert mock_mqtt_topic.encode_message_properties_in_topic.call_args == mocker.call( - op.message, mock_mqtt_topic.get_telemetry_topic_for_publish.return_value - ) - - @pytest.mark.it( - "Sends a new MQTTPublishOperation down the pipeline with the message data from the original op and the derived topic string" - ) - def test_sends_mqtt_publish_op_down(self, mocker, stage, op, mock_mqtt_topic): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_mqtt.MQTTPublishOperation) - assert new_op.topic == mock_mqtt_topic.encode_message_properties_in_topic.return_value - assert new_op.payload == op.message.data - - @pytest.mark.it("Completes the original op upon completion of the new MQTTPublishOperation") - def test_complete_resulting_op(self, stage, op, op_error): - stage.run_op(op) - assert not op.completed - - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - - new_op.complete(error=op_error) - - assert new_op.completed - assert new_op.error is op_error - assert op.completed - assert op.error is op_error - - -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .run_op() -- Called with SendMethodResponseOperation" -) -class TestIoTHubMQTTTranslationStageWithSendMethodResponseOperation( - StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig -): - @pytest.fixture - def op(self, mocker): - method_response = MethodResponse( - request_id="fake_request_id", status=200, payload={"some": "json"} - ) - return pipeline_ops_iothub.SendMethodResponseOperation( - method_response=method_response, 
callback=mocker.MagicMock() - ) - - @pytest.mark.it("Derives the IoTHub telemetry topic using the op's request id and status") - def test_telemetry_topic(self, mocker, stage, op, mock_mqtt_topic): - stage.run_op(op) - - assert mock_mqtt_topic.get_method_topic_for_publish.call_count == 1 - assert mock_mqtt_topic.get_method_topic_for_publish.call_args == mocker.call( - op.method_response.request_id, op.method_response.status - ) - - @pytest.mark.it( - "Sends a new MQTTPublishOperation down the pipeline with the original op's payload in JSON string format, and the derived topic string" - ) - @pytest.mark.parametrize( - "payload, expected_string", - [ - pytest.param(None, "null", id="No payload"), - pytest.param({"some": "json"}, '{"some": "json"}', id="Dictionary payload"), - pytest.param("payload", '"payload"', id="String payload"), - ], - ) - def test_sends_mqtt_publish_op_down( - self, mocker, stage, op, mock_mqtt_topic, payload, expected_string - ): - op.method_response.payload = payload - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_mqtt.MQTTPublishOperation) - assert new_op.topic == mock_mqtt_topic.get_method_topic_for_publish.return_value - assert new_op.payload == expected_string - - @pytest.mark.it("Completes the original op upon completion of the new MQTTPublishOperation") - def test_complete_resulting_op(self, stage, op, op_error): - stage.run_op(op) - assert not op.completed - - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - - new_op.complete(error=op_error) - - assert new_op.completed - assert new_op.error is op_error - assert op.completed - assert op.error is op_error - - -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .run_op() -- Called with EnableFeatureOperation" -) -class TestIoTHubMQTTTranslationStageRunOpWithEnableFeatureOperation( - StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig -): 
- @pytest.fixture - def op(self, mocker, iothub_pipeline_feature): - return pipeline_ops_base.EnableFeatureOperation( - feature_name=iothub_pipeline_feature, callback=mocker.MagicMock() - ) - - @pytest.mark.it( - "Sends a new MQTTSubscribeOperation down the pipeline, containing the subscription topic string corresponding to the feature being enabled" - ) - def test_mqtt_subscribe_sent_down( - self, op, stage, expected_mqtt_topic_fn, expected_mqtt_topic_fn_call - ): - stage.run_op(op) - - # Topic was derived as expected - assert expected_mqtt_topic_fn.call_count == 1 - assert expected_mqtt_topic_fn.call_args == expected_mqtt_topic_fn_call - - # New op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_mqtt.MQTTSubscribeOperation) - - # New op has the expected topic - assert new_op.topic == expected_mqtt_topic_fn.return_value - - @pytest.mark.it("Completes the original op upon completion of the new MQTTSubscribeOperation") - def test_complete_resulting_op(self, stage, op, op_error): - stage.run_op(op) - assert not op.completed - - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - - new_op.complete(error=op_error) - - assert new_op.completed - assert new_op.error is op_error - assert op.completed - assert op.error is op_error - - -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .run_op() -- Called with DisableFeatureOperation" -) -class TestIoTHubMQTTTranslationStageRunOpWithDisableFeatureOperation( - StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig -): - @pytest.fixture - def op(self, mocker, iothub_pipeline_feature): - return pipeline_ops_base.DisableFeatureOperation( - feature_name=iothub_pipeline_feature, callback=mocker.MagicMock() - ) - - @pytest.mark.it( - "Sends a new MQTTUnsubscribeOperation down the pipeline, containing the subscription topic string corresponding to the feature being disabled" - ) - def 
test_mqtt_unsubscribe_sent_down( - self, op, stage, expected_mqtt_topic_fn, expected_mqtt_topic_fn_call - ): - stage.run_op(op) - - # Topic was derived as expected - assert expected_mqtt_topic_fn.call_count == 1 - assert expected_mqtt_topic_fn.call_args == expected_mqtt_topic_fn_call - - # New op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_mqtt.MQTTUnsubscribeOperation) - - # New op has the expected topic - assert new_op.topic == expected_mqtt_topic_fn.return_value - - @pytest.mark.it("Completes the original op upon completion of the new MQTTUnsubscribeOperation") - def test_complete_resulting_op(self, stage, op, op_error): - stage.run_op(op) - assert not op.completed - - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - - new_op.complete(error=op_error) - - assert new_op.completed - assert new_op.error is op_error - assert op.completed - assert op.error is op_error - - -@pytest.mark.describe("IoTHubMQTTTranslationStage - .run_op() -- Called with RequestOperation") -class TestIoTHubMQTTTranslationStageWithRequestOperation( - StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig -): - @pytest.fixture - def op(self, mocker): - # Only request operation supported at present by this stage is TWIN. 
If this changes, - # logic in this whole test class must become more robust - return pipeline_ops_base.RequestOperation( - request_type=constant.TWIN, - method="GET", - resource_location="/", - request_body=" ", - request_id="fake_request_id", - callback=mocker.MagicMock(), - ) - - @pytest.mark.it( - "Derives the IoTHub Twin Request topic using the op's details, if the op is a Twin Request" - ) - def test_twin_request_topic(self, mocker, stage, op, mock_mqtt_topic): - stage.run_op(op) - - assert mock_mqtt_topic.get_twin_topic_for_publish.call_count == 1 - assert mock_mqtt_topic.get_twin_topic_for_publish.call_args == mocker.call( - method=op.method, resource_location=op.resource_location, request_id=op.request_id - ) - - @pytest.mark.it( - "Completes the operation with an OperationError failure if the op is any type of request other than a Twin Request" - ) - def test_invalid_op(self, mocker, stage, op): - # Okay, so technically this does'nt prove it does this if it's ANY other type of request, but that's pretty much - # impossible to disprove in a black-box test, because there are infinite possibilities in theory - op.request_type = "Some_other_type" - stage.run_op(op) - assert op.completed - assert isinstance(op.error, OperationError) - - @pytest.mark.it( - "Sends a new MQTTPublishOperation down the pipeline with the original op's request body and the derived topic string" - ) - def test_sends_mqtt_publish_op_down(self, mocker, stage, op, mock_mqtt_topic): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_mqtt.MQTTPublishOperation) - assert new_op.topic == mock_mqtt_topic.get_twin_topic_for_publish.return_value - assert new_op.payload == op.request_body - - @pytest.mark.it("Completes the original op upon completion of the new MQTTPublishOperation") - def test_complete_resulting_op(self, stage, op, op_error): - stage.run_op(op) - assert not op.completed - - assert 
stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - - new_op.complete(error=op_error) - - assert new_op.completed - assert new_op.error is op_error - assert op.completed - assert op.error is op_error - - -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .run_op() -- Called with other arbitrary operation" -) -class TestIoTHubMQTTTranslationStageRunOpWithArbitraryOperation( - StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig -): - @pytest.fixture - def op(self, arbitrary_op): - return arbitrary_op - - @pytest.mark.it("Sends the operation down the pipeline") - def test_sends_op_down(self, mocker, stage, op): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- Called with IncomingMQTTMessageEvent (C2D topic string)" -) -class TestIoTHubMQTTTranslationStageHandlePipelineEventWithIncomingMQTTMessageEventC2DTopic( - StageHandlePipelineEventTestBase, IoTHubMQTTTranslationStageTestConfig -): - @pytest.fixture - def event(self, pipeline_config): - # topic device id MATCHES THE PIPELINE CONFIG - topic = "devices/{device_id}/messages/devicebound/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2F{device_id}%2Fmessages%2Fdevicebound".format( - device_id=pipeline_config.device_id - ) - return pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic, payload="some payload") - - @pytest.mark.it( - "Creates a Message with the event's payload, and applies any message properties included in the topic" - ) - def test_message(self, event, stage, mock_mqtt_topic): - stage.handle_pipeline_event(event) - - # Message properties were extracted from the topic - # NOTE that because this is mocked, we don't need to test various topics with various properties - assert mock_mqtt_topic.extract_message_properties_from_topic.call_count == 1 - assert 
mock_mqtt_topic.extract_message_properties_from_topic.call_args[0][0] == event.topic - message = mock_mqtt_topic.extract_message_properties_from_topic.call_args[0][1] - assert isinstance(message, Message) - # The message contains the event's payload - assert message.data == event.payload - - @pytest.mark.it( - "Sends a new C2DMessageEvent up the pipeline, containing the newly created Message" - ) - def test_c2d_message_event(self, event, stage, mock_mqtt_topic): - stage.handle_pipeline_event(event) - - # C2DMessageEvent was sent up the pipeline - assert stage.send_event_up.call_count == 1 - new_event = stage.send_event_up.call_args[0][0] - assert isinstance(new_event, pipeline_events_iothub.C2DMessageEvent) - # The C2DMessageEvent contains the same Message that was created from the topic details - assert mock_mqtt_topic.extract_message_properties_from_topic.call_count == 1 - message = mock_mqtt_topic.extract_message_properties_from_topic.call_args[0][1] - assert new_event.message is message - - @pytest.mark.it( - "Sends the original event up the pipeline instead, if the device id in the topic string does not match the client details" - ) - def test_nonmatching_device_id(self, mocker, event, stage): - stage.nucleus.pipeline_configuration.device_id = "different_device_id" - stage.handle_pipeline_event(event) - - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args == mocker.call(event) - - -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- Called with IncomingMQTTMessageEvent (Input Message topic string)" -) -class TestIoTHubMQTTTranslationStageHandlePipelineEventWithIncomingMQTTMessageEventInputTopic( - StageHandlePipelineEventTestBase, IoTHubMQTTTranslationStageTestConfig -): - @pytest.fixture - def pipeline_config(self, mocker): - cfg = config.IoTHubPipelineConfig( - hostname="fake_hostname", - device_id="my_device", - module_id="my_module", - sastoken=mocker.MagicMock(), - ) - return cfg - - 
@pytest.fixture - def input_name(self): - return "some_input" - - @pytest.fixture - def event(self, pipeline_config, input_name): - # topic device id MATCHES THE PIPELINE CONFIG - topic = "devices/{device_id}/modules/{module_id}/inputs/{input_name}/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2F{device_id}%2Fmodules%2F{module_id}%2Finputs%2F{input_name}".format( - device_id=pipeline_config.device_id, - module_id=pipeline_config.module_id, - input_name=input_name, - ) - return pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic, payload="some payload") - - @pytest.mark.it( - "Creates a Message with the event's payload, and applies any message properties included in the topic" - ) - def test_message(self, event, stage, mock_mqtt_topic): - stage.handle_pipeline_event(event) - - # Message properties were extracted from the topic - # NOTE that because this is mocked, we don't need to test various topics with various properties - assert mock_mqtt_topic.extract_message_properties_from_topic.call_count == 1 - assert mock_mqtt_topic.extract_message_properties_from_topic.call_args[0][0] == event.topic - message = mock_mqtt_topic.extract_message_properties_from_topic.call_args[0][1] - assert isinstance(message, Message) - # The message contains the event's payload - assert message.data == event.payload - - @pytest.mark.it( - "Sends a new InputMessageEvent up the pipeline, containing the newly created Message with the input name extracted from the topic" - ) - def test_input_message_event(self, event, stage, mock_mqtt_topic, input_name): - stage.handle_pipeline_event(event) - - # InputMessageEvent was sent up the pipeline - assert stage.send_event_up.call_count == 1 - new_event = stage.send_event_up.call_args[0][0] - assert isinstance(new_event, pipeline_events_iothub.InputMessageEvent) - # The InputMessageEvent contains the same Message that was created from the topic details - assert mock_mqtt_topic.extract_message_properties_from_topic.call_count == 
1 - message = mock_mqtt_topic.extract_message_properties_from_topic.call_args[0][1] - assert new_event.message is message - # The Message contains the same input name from the topic - assert new_event.message.input_name == input_name - - @pytest.mark.it( - "Sends the original event up the pipeline instead, if the the topic string does not match the client details" - ) - @pytest.mark.parametrize( - "alt_device_id, alt_module_id", - [ - pytest.param("different_device_id", None, id="Non-matching device id"), - pytest.param(None, "different_module_id", id="Non-matching module id"), - pytest.param( - "different_device_id", - "different_module_id", - id="Non-matching device id AND module id", - ), - ], - ) - def test_nonmatching_ids(self, mocker, event, stage, alt_device_id, alt_module_id): - if alt_device_id: - stage.nucleus.pipeline_configuration.device_id = alt_device_id - if alt_module_id: - stage.nucleus.pipeline_configuration.module_id = alt_module_id - stage.handle_pipeline_event(event) - - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args == mocker.call(event) - - -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- Called with IncomingMQTTMessageEvent (Method topic string)" -) -class TestIoTHubMQTTTranslationStageHandlePipelineEventWithIncomingMQTTMessageEventMethodTopic( - StageHandlePipelineEventTestBase, IoTHubMQTTTranslationStageTestConfig -): - @pytest.fixture - def method_name(self): - return "some_method" - - @pytest.fixture - def rid(self): - return "1" - - @pytest.fixture - def event(self, method_name, rid): - topic = "$iothub/methods/POST/{method_name}/?$rid={rid}".format( - method_name=method_name, rid=rid - ) - return pipeline_events_mqtt.IncomingMQTTMessageEvent( - topic=topic, payload=b'{"some": "json"}' - ) - - @pytest.mark.it( - "Sends a MethodRequestEvent up the pipeline with a MethodRequest containing values extracted from the event's topic" - ) - def test_method_request(self, 
event, stage, method_name, rid): - stage.handle_pipeline_event(event) - - assert stage.send_event_up.call_count == 1 - new_event = stage.send_event_up.call_args[0][0] - assert isinstance(new_event, pipeline_events_iothub.MethodRequestEvent) - assert isinstance(new_event.method_request, MethodRequest) - assert new_event.method_request.name == method_name - assert new_event.method_request.request_id == rid - # This is expanded on in in the next test - assert new_event.method_request.payload == json.loads(event.payload.decode("utf-8")) - - @pytest.mark.it( - "Derives the MethodRequestEvent's payload by converting the original event's payload from bytes into a JSON object" - ) - @pytest.mark.parametrize( - "original_payload, derived_payload", - [ - pytest.param(b'{"some": "payload"}', {"some": "payload"}, id="Dictionary JSON"), - pytest.param(b'"payload"', "payload", id="String JSON"), - pytest.param(b"1234", 1234, id="Int JSON"), - pytest.param(b"null", None, id="None JSON"), - ], - ) - def test_json_payload(self, event, stage, original_payload, derived_payload): - event.payload = original_payload - stage.handle_pipeline_event(event) - - assert stage.send_event_up.call_count == 1 - new_event = stage.send_event_up.call_args[0][0] - assert isinstance(new_event, pipeline_events_iothub.MethodRequestEvent) - assert isinstance(new_event.method_request, MethodRequest) - - assert new_event.method_request.payload == derived_payload - - -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- Called with IncomingMQTTMessageEvent (Twin Response topic string)" -) -class TestIoTHubMQTTTranslationStageHandlePipelineEventWithIncomingMQTTMessageEventTwinResponseTopic( - StageHandlePipelineEventTestBase, IoTHubMQTTTranslationStageTestConfig -): - @pytest.fixture - def status(self): - return 200 - - @pytest.fixture - def rid(self): - return "d9d7ce4d-3be9-498b-abde-913b81b880e5" - - @pytest.fixture - def event(self, status, rid): - topic = 
"$iothub/twin/res/{status}/?$rid={rid}".format(status=status, rid=rid) - return pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic, payload=b"some_payload") - - @pytest.mark.it( - "Sends a ResponseEvent up the pipeline containing the original event's payload, and values extracted from the topic string" - ) - def test_response_event(self, event, stage, status, rid): - stage.handle_pipeline_event(event) - - assert stage.send_event_up.call_count == 1 - new_event = stage.send_event_up.call_args[0][0] - assert isinstance(new_event, pipeline_events_base.ResponseEvent) - assert new_event.status_code == status - assert new_event.request_id == rid - assert new_event.response_body == event.payload - - -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- Called with IncomingMQTTMessageEvent (Twin Desired Properties Patch topic string)" -) -class TestIoTHubMQTTTranslationStageHandlePipelineEventWithIncomingMQTTMessageEventTwinDesiredPropertiesPatchTopic( - StageHandlePipelineEventTestBase, IoTHubMQTTTranslationStageTestConfig -): - @pytest.fixture - def event(self): - topic = "$iothub/twin/PATCH/properties/desired/?$version=1" - # payload will be overwritten in relevant tests - return pipeline_events_mqtt.IncomingMQTTMessageEvent( - topic=topic, payload=b'{"some": "payload"}' - ) - - @pytest.mark.it( - "Sends a TwinDesiredPropertiesPatchEvent up the pipeline, containing the original event's payload formatted as a JSON-object" - ) - @pytest.mark.parametrize( - "original_payload, derived_payload", - [ - pytest.param(b'{"some": "payload"}', {"some": "payload"}, id="Dictionary JSON"), - pytest.param(b'"payload"', "payload", id="String JSON"), - pytest.param(b"1234", 1234, id="Int JSON"), - pytest.param(b"null", None, id="None JSON"), - ], - ) - def test_twin_patch_event(self, event, stage, original_payload, derived_payload): - event.payload = original_payload - stage.handle_pipeline_event(event) - - assert stage.send_event_up.call_count == 1 - 
new_event = stage.send_event_up.call_args[0][0] - assert isinstance(new_event, pipeline_events_iothub.TwinDesiredPropertiesPatchEvent) - assert new_event.patch == derived_payload - - -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- Called with IncomingMQTTMessageEvent (Unrecognized topic string)" -) -class TestIoTHubMQTTTranslationStageHandlePipelineEventWithIncomingMQTTMessageEventUnknownTopicString( - StageHandlePipelineEventTestBase, IoTHubMQTTTranslationStageTestConfig -): - @pytest.fixture - def event(self): - topic = "not a real topic" - return pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic, payload=b"some payload") - - @pytest.mark.it("Sends the event up the pipeline") - def test_sends_up(self, event, stage): - stage.handle_pipeline_event(event) - - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args[0][0] == event - - -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- Called with other arbitrary event" -) -class TestIoTHubMQTTTranslationStageHandlePipelineEventWithArbitraryEvent( - StageHandlePipelineEventTestBase, IoTHubMQTTTranslationStageTestConfig -): - @pytest.fixture - def event(self, arbitrary_event): - return arbitrary_event - - @pytest.mark.it("Sends the event up the pipeline") - def test_sends_up(self, event, stage): - stage.handle_pipeline_event(event) - - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args[0][0] == event diff --git a/tests/unit/iothub/shared_client_tests.py b/tests/unit/iothub/shared_client_tests.py deleted file mode 100644 index 488722656..000000000 --- a/tests/unit/iothub/shared_client_tests.py +++ /dev/null @@ -1,2059 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -"""This module contains tests that are shared between sync/async clients -i.e. tests for things defined in abstract clients""" - -import pytest -import logging -import os -import io -import time -import urllib -from azure.iot.device.common import auth, handle_exceptions -from azure.iot.device.common.auth import sastoken as st -from azure.iot.device.common.auth import connection_string as cs -from azure.iot.device.iothub.pipeline import IoTHubPipelineConfig -from azure.iot.device.common.pipeline.config import DEFAULT_KEEPALIVE -from azure.iot.device.iothub.abstract_clients import ( - RECEIVE_TYPE_NONE_SET, - RECEIVE_TYPE_HANDLER, - RECEIVE_TYPE_API, -) -from azure.iot.device.iothub import edge_hsm -from azure.iot.device.iothub import client_event -from azure.iot.device import ProxyOptions -from azure.iot.device import exceptions as client_exceptions - -logging.basicConfig(level=logging.DEBUG) - - -#################### -# HELPER FUNCTIONS # -#################### - - -def token_parser(token_str): - """helper function that parses a token string for individual values""" - token_map = {} - kv_string = token_str.split(" ")[1] - kv_pairs = kv_string.split("&") - for kv in kv_pairs: - t = kv.split("=") - token_map[t[0]] = t[1] - return token_map - - -################################ -# SHARED DEVICE + MODULE TESTS # -################################ - - -class SharedIoTHubClientInstantiationTests(object): - @pytest.mark.it( - "Stores the MQTTPipeline from the 'mqtt_pipeline' parameter in the '_mqtt_pipeline' attribute" - ) - def test_mqtt_pipeline_attribute(self, client_class, mqtt_pipeline, http_pipeline): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline is mqtt_pipeline - - @pytest.mark.it( - "Stores the HTTPPipeline from the 'http_pipeline' parameter in the '_http_pipeline' attribute" - ) - def test_sets_http_pipeline_attribute(self, client_class, 
mqtt_pipeline, http_pipeline): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._http_pipeline is http_pipeline - - @pytest.mark.it("Sets on_connected handler in the MQTTPipeline") - def test_sets_on_connected_handler_in_pipeline( - self, client_class, mqtt_pipeline, http_pipeline - ): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline.on_connected is not None - assert client._mqtt_pipeline.on_connected == client._on_connected - - @pytest.mark.it("Sets on_disconnected handler in the MQTTPipeline") - def test_sets_on_disconnected_handler_in_pipeline( - self, client_class, mqtt_pipeline, http_pipeline - ): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline.on_disconnected is not None - assert client._mqtt_pipeline.on_disconnected == client._on_disconnected - - @pytest.mark.it("Sets on_new_sastoken_required handler in the MQTTPipeline") - def test_sets_on_new_sastoken_required_handler_in_pipeline( - self, client_class, mqtt_pipeline, http_pipeline - ): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline.on_new_sastoken_required is not None - assert client._mqtt_pipeline.on_new_sastoken_required == client._on_new_sastoken_required - - @pytest.mark.it("Sets on_background_exception handler in the MQTTPipeline") - def test_sets_on_background_exception_handler(self, client_class, mqtt_pipeline, http_pipeline): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline.on_background_exception is not None - assert client._mqtt_pipeline.on_background_exception == client._on_background_exception - - @pytest.mark.it("Sets on_method_request_received handler in the MQTTPipeline") - def test_sets_on_method_request_received_handler_in_pipeline( - self, client_class, mqtt_pipeline, http_pipeline - ): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline.on_method_request_received is not None - 
assert ( - client._mqtt_pipeline.on_method_request_received - == client._inbox_manager.route_method_request - ) - - @pytest.mark.it("Sets on_twin_patch_received handler in the MQTTPipeline") - def test_sets_on_twin_patch_received_handler_in_pipeline( - self, client_class, mqtt_pipeline, http_pipeline - ): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline.on_twin_patch_received is not None - assert ( - client._mqtt_pipeline.on_twin_patch_received == client._inbox_manager.route_twin_patch - ) - - @pytest.mark.it("Sets the Receive Mode/Type for the client as yet-unchosen") - def test_initial_receive_mode(self, client_class, mqtt_pipeline, http_pipeline): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._receive_type == RECEIVE_TYPE_NONE_SET - - -@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") -class SharedIoTHubClientCreateMethodUserOptionTests(object): - @pytest.fixture - def option_test_required_patching(self, mocker): - """Override this fixture in a subclass if unique patching is required""" - pass - - @pytest.mark.it( - "Sets the 'product_info' user option parameter on the PipelineConfig, if provided" - ) - def test_product_info_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - - product_info = "MyProductInfo" - client_create_method(*create_method_args, product_info=product_info) - - # Get configuration object, and ensure it was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - config = mock_mqtt_pipeline_init.call_args[0][0] - assert isinstance(config, IoTHubPipelineConfig) - assert config == mock_http_pipeline_init.call_args[0][0] - - assert config.product_info == product_info - - @pytest.mark.it( - "Sets the 'ensure_desired_properties' user option parameter on the PipelineConfig, if provided" - ) - def 
test_ensure_desired_properties_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - - client_create_method(*create_method_args, ensure_desired_properties=True) - - # Get configuration object, and ensure it was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - config = mock_mqtt_pipeline_init.call_args[0][0] - assert isinstance(config, IoTHubPipelineConfig) - assert config == mock_http_pipeline_init.call_args[0][0] - - assert config.ensure_desired_properties is True - - @pytest.mark.it( - "Sets the 'websockets' user option parameter on the PipelineConfig, if provided" - ) - def test_websockets_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - - client_create_method(*create_method_args, websockets=True) - - # Get configuration object, and ensure it was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - config = mock_mqtt_pipeline_init.call_args[0][0] - assert isinstance(config, IoTHubPipelineConfig) - assert config == mock_http_pipeline_init.call_args[0][0] - - assert config.websockets - - # TODO: Show that input in the wrong format is formatted to the correct one. This test exists - # in the IoTHubPipelineConfig object already, but we do not currently show that this is felt - # from the API level. 
- @pytest.mark.it("Sets the 'cipher' user option parameter on the PipelineConfig, if provided") - def test_cipher_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - cipher = "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256" - client_create_method(*create_method_args, cipher=cipher) - - # Get configuration object, and ensure it was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - config = mock_mqtt_pipeline_init.call_args[0][0] - assert isinstance(config, IoTHubPipelineConfig) - assert config == mock_http_pipeline_init.call_args[0][0] - - assert config.cipher == cipher - - @pytest.mark.it( - "Sets the 'server_verification_cert' user option parameter on the PipelineConfig, if provided" - ) - def test_server_verification_cert_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - server_verification_cert = "fake_server_verification_cert" - client_create_method(*create_method_args, server_verification_cert=server_verification_cert) - - # Get configuration object, and ensure it was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - config = mock_mqtt_pipeline_init.call_args[0][0] - assert isinstance(config, IoTHubPipelineConfig) - assert config == mock_http_pipeline_init.call_args[0][0] - - assert config.server_verification_cert == server_verification_cert - - @pytest.mark.it( - "Sets the 'gateway_hostname' user option parameter on the PipelineConfig, if provided" - ) - def test_gateway_hostname_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - gateway_hostname = "my.gateway" - client_create_method(*create_method_args, gateway_hostname=gateway_hostname) - - # Get configuration 
object, and ensure it was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - config = mock_mqtt_pipeline_init.call_args[0][0] - assert isinstance(config, IoTHubPipelineConfig) - assert config == mock_http_pipeline_init.call_args[0][0] - - assert config.gateway_hostname == gateway_hostname - - @pytest.mark.it( - "Sets the 'proxy_options' user option parameter on the PipelineConfig, if provided" - ) - def test_proxy_options( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - proxy_options = ProxyOptions(proxy_type="HTTP", proxy_addr="127.0.0.1", proxy_port=8888) - client_create_method(*create_method_args, proxy_options=proxy_options) - - # Get configuration object, and ensure it was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - config = mock_mqtt_pipeline_init.call_args[0][0] - assert isinstance(config, IoTHubPipelineConfig) - assert config == mock_http_pipeline_init.call_args[0][0] - - assert config.proxy_options is proxy_options - - @pytest.mark.it( - "Sets the 'keep_alive' user option parameter on the PipelineConfig, if provided" - ) - def test_keep_alive_options( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - keepalive_value = 60 - client_create_method(*create_method_args, keep_alive=keepalive_value) - - # Get configuration object, and ensure it was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - config = mock_mqtt_pipeline_init.call_args[0][0] - assert isinstance(config, IoTHubPipelineConfig) - assert config == mock_http_pipeline_init.call_args[0][0] - - assert config.keep_alive == keepalive_value - - @pytest.mark.it( - "Sets the 'auto_connect' user option parameter on the PipelineConfig, if provided" - ) - def test_auto_connect_option( - self, - 
option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - auto_connect_value = False - client_create_method(*create_method_args, auto_connect=auto_connect_value) - - # Get configuration object, and ensure it was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - config = mock_mqtt_pipeline_init.call_args[0][0] - assert isinstance(config, IoTHubPipelineConfig) - assert config == mock_http_pipeline_init.call_args[0][0] - - assert config.auto_connect == auto_connect_value - - @pytest.mark.it( - "Sets the 'connection_retry' user option parameter on the PipelineConfig, if provided" - ) - def test_connection_retry_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - connection_retry_value = False - client_create_method(*create_method_args, connection_retry=connection_retry_value) - - # Get configuration object, and ensure it was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - config = mock_mqtt_pipeline_init.call_args[0][0] - assert isinstance(config, IoTHubPipelineConfig) - assert config == mock_http_pipeline_init.call_args[0][0] - - assert config.connection_retry == connection_retry_value - - @pytest.mark.it( - "Sets the 'connection_retry_interval' user option parameter on the PipelineConfig, if provided" - ) - def test_connection_retry_interval_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - connection_retry_interval_value = 17 - client_create_method( - *create_method_args, connection_retry_interval=connection_retry_interval_value - ) - - # Get configuration object, and ensure it was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - config = mock_mqtt_pipeline_init.call_args[0][0] 
- assert isinstance(config, IoTHubPipelineConfig) - assert config == mock_http_pipeline_init.call_args[0][0] - - assert config.connection_retry_interval == connection_retry_interval_value - - @pytest.mark.it("Raises a TypeError if an invalid user option parameter is provided") - def test_invalid_option( - self, option_test_required_patching, client_create_method, create_method_args - ): - with pytest.raises(TypeError): - client_create_method(*create_method_args, invalid_option="some_value") - - # NOTE: If any further tests need to override this test, it's time to restructure. - @pytest.mark.it("Sets default user options if none are provided") - def test_default_options( - self, - mocker, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - client_create_method(*create_method_args) - - # Both pipelines use the same IoTHubPipelineConfig - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_http_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args[0][0] is mock_http_pipeline_init.call_args[0][0] - config = mock_mqtt_pipeline_init.call_args[0][0] - assert isinstance(config, IoTHubPipelineConfig) - - # Pipeline Config has default options set that were not user-specified - assert config.product_info == "" - assert config.websockets is False - assert config.cipher == "" - assert config.proxy_options is None - assert config.server_verification_cert is None - assert config.gateway_hostname is None - assert config.keep_alive == DEFAULT_KEEPALIVE - assert config.auto_connect is True - assert config.connection_retry is True - assert config.connection_retry_interval == 10 - assert config.ensure_desired_properties is True - - -# TODO: consider splitting this test class up into device/module specific test classes to avoid -# the conditional logic in some tests -@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") -class 
SharedIoTHubClientCreateFromConnectionStringTests( - SharedIoTHubClientCreateMethodUserOptionTests -): - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_connection_string - - @pytest.fixture - def create_method_args(self, connection_string): - """Provides the specific create method args for use in universal tests""" - return [connection_string] - - @pytest.mark.it( - "Raises a TypeError if the 'gateway_hostname' user option parameter is provided" - ) - def test_gateway_hostname_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - """THIS TEST OVERRIDES AN INHERITED TEST""" - # Override to test that gateway_hostname CANNOT be provided in Edge scenarios - - with pytest.raises(TypeError): - client_create_method(*create_method_args, gateway_hostname="my.gateway.device") - - @pytest.mark.it("Sets default user options if none are provided") - def test_default_options( - self, - mocker, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - """THIS TEST OVERRIDES AN INHERITED TEST""" - # Override to remove an assertion about gateway_hostname - - client_create_method(*create_method_args) - - # Both pipelines use the same IoTHubPipelineConfig - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_http_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args[0][0] is mock_http_pipeline_init.call_args[0][0] - config = mock_mqtt_pipeline_init.call_args[0][0] - assert isinstance(config, IoTHubPipelineConfig) - - # Pipeline Config has default options set that were not user-specified - assert config.product_info == "" - assert config.websockets is False - assert config.cipher == "" - assert config.proxy_options is None - assert 
config.server_verification_cert is None - assert config.keep_alive == DEFAULT_KEEPALIVE - assert config.auto_connect is True - assert config.connection_retry is True - assert config.connection_retry_interval == 10 - - @pytest.mark.it( - "Creates a SasToken that uses a SymmetricKeySigningMechanism, from the values in the provided connection string" - ) - def test_sastoken(self, mocker, client_class, connection_string): - sksm_mock = mocker.patch.object(auth, "SymmetricKeySigningMechanism") - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - cs_obj = cs.ConnectionString(connection_string) - - custom_ttl = 1000 - client_class.create_from_connection_string(connection_string, sastoken_ttl=custom_ttl) - - # Determine expected URI based on class under test - if client_class.__name__ == "IoTHubDeviceClient": - expected_uri = "{hostname}/devices/{device_id}".format( - hostname=cs_obj[cs.HOST_NAME], device_id=cs_obj[cs.DEVICE_ID] - ) - else: - expected_uri = "{hostname}/devices/{device_id}/modules/{module_id}".format( - hostname=cs_obj[cs.HOST_NAME], - device_id=cs_obj[cs.DEVICE_ID], - module_id=cs_obj[cs.MODULE_ID], - ) - - # SymmetricKeySigningMechanism created using the connection string's SharedAccessKey - assert sksm_mock.call_count == 1 - assert sksm_mock.call_args == mocker.call(key=cs_obj[cs.SHARED_ACCESS_KEY]) - - # Token was created with a SymmetricKeySigningMechanism, the expected URI, and custom ttl - assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call( - expected_uri, sksm_mock.return_value, ttl=custom_ttl - ) - - @pytest.mark.it( - "Uses 3600 seconds (1 hour) as the default SasToken TTL if no custom TTL is provided" - ) - def test_sastoken_default(self, mocker, client_class, connection_string): - sksm_mock = mocker.patch.object(auth, "SymmetricKeySigningMechanism") - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - cs_obj = cs.ConnectionString(connection_string) - - 
client_class.create_from_connection_string(connection_string) - - # Determine expected URI based on class under test - if client_class.__name__ == "IoTHubDeviceClient": - expected_uri = "{hostname}/devices/{device_id}".format( - hostname=cs_obj[cs.HOST_NAME], device_id=cs_obj[cs.DEVICE_ID] - ) - else: - expected_uri = "{hostname}/devices/{device_id}/modules/{module_id}".format( - hostname=cs_obj[cs.HOST_NAME], - device_id=cs_obj[cs.DEVICE_ID], - module_id=cs_obj[cs.MODULE_ID], - ) - - # SymmetricKeySigningMechanism created using the connection string's SharedAccessKey - assert sksm_mock.call_count == 1 - assert sksm_mock.call_args == mocker.call(key=cs_obj[cs.SHARED_ACCESS_KEY]) - - # Token was created with a SymmetricKeySigningMechanism, the expected URI, and default ttl - assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call( - expected_uri, sksm_mock.return_value, ttl=3600 - ) - - @pytest.mark.it( - "Creates MQTT and HTTP Pipelines with an IoTHubPipelineConfig object containing the SasToken and values from the connection string" - ) - def test_pipeline_config( - self, - mocker, - client_class, - connection_string, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - cs_obj = cs.ConnectionString(connection_string) - - client_class.create_from_connection_string(connection_string) - - # Verify pipelines created with an IoTHubPipelineConfig - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_http_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args[0][0] is mock_http_pipeline_init.call_args[0][0] - assert isinstance(mock_mqtt_pipeline_init.call_args[0][0], IoTHubPipelineConfig) - - # Verify the IoTHubPipelineConfig is constructed as expected - config = mock_mqtt_pipeline_init.call_args[0][0] - assert config.device_id == cs_obj[cs.DEVICE_ID] - assert config.hostname == cs_obj[cs.HOST_NAME] - assert config.sastoken is 
sastoken_mock.return_value - if client_class.__name__ == "IoTHubModuleClient": - assert config.module_id == cs_obj[cs.MODULE_ID] - assert config.blob_upload is False - assert config.method_invoke is False - else: - assert config.module_id is None - assert config.blob_upload is True - assert config.method_invoke is False - if cs_obj.get(cs.GATEWAY_HOST_NAME): - assert config.gateway_hostname == cs_obj[cs.GATEWAY_HOST_NAME] - else: - assert config.gateway_hostname is None - - @pytest.mark.it( - "Returns an instance of an IoTHub client using the created MQTT and HTTP pipelines" - ) - def test_client_returned( - self, - mocker, - client_class, - connection_string, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - client = client_class.create_from_connection_string(connection_string) - assert isinstance(client, client_class) - assert client._mqtt_pipeline is mock_mqtt_pipeline_init.return_value - assert client._http_pipeline is mock_http_pipeline_init.return_value - - @pytest.mark.it("Raises ValueError when given an invalid connection string") - @pytest.mark.parametrize( - "bad_cs", - [ - pytest.param("not-a-connection-string", id="Garbage string"), - pytest.param( - "HostName=value.domain.net;DeviceId=my_device;SharedAccessKey=Invalid", - id="Shared Access Key invalid", - ), - pytest.param( - "HostName=value.domain.net;WrongValue=Invalid;SharedAccessKey=Zm9vYmFy", - id="Contains extraneous data", - ), - pytest.param("HostName=value.domain.net;DeviceId=my_device", id="Incomplete"), - pytest.param( - "HostName=value.domain.net;DeviceId=my_device;x509=True", - id="X509 Connection String", - ), - ], - ) - def test_raises_value_error_on_bad_connection_string(self, client_class, bad_cs): - with pytest.raises(ValueError): - client_class.create_from_connection_string(bad_cs) - - @pytest.mark.it("Raises ValueError if a SasToken creation results in failure") - def test_raises_value_error_on_sastoken_failure(self, mocker, client_class, connection_string): - 
sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - token_err = st.SasTokenError("Some SasToken failure") - sastoken_mock.side_effect = token_err - - with pytest.raises(ValueError) as e_info: - client_class.create_from_connection_string(connection_string) - assert e_info.value.__cause__ is token_err - - -class SharedIoTHubClientPROPERTYHandlerTests(object): - @pytest.mark.it("Can have its value set and retrieved") - def test_read_write(self, client, handler, handler_name): - assert getattr(client, handler_name) is None - setattr(client, handler_name, handler) - assert getattr(client, handler_name) is handler - - @pytest.mark.it("Reflects the value of the handler manager property of the same name") - def test_set_on_handler_manager(self, client, handler, handler_name): - assert getattr(client, handler_name) is None - assert getattr(client, handler_name) is getattr(client._handler_manager, handler_name) - setattr(client, handler_name, handler) - assert getattr(client, handler_name) is handler - assert getattr(client, handler_name) is getattr(client._handler_manager, handler_name) - - -class SharedIoTHubClientPROPERTYReceiverHandlerTests(SharedIoTHubClientPROPERTYHandlerTests): - @pytest.mark.it("Can have its value set and retrieved") - def test_read_write(self, client, handler, handler_name): - assert getattr(client, handler_name) is None - setattr(client, handler_name, handler) - assert getattr(client, handler_name) is handler - - @pytest.mark.it("Reflects the value of the handler manager property of the same name") - def test_set_on_handler_manager(self, client, handler, handler_name): - assert getattr(client, handler_name) is None - assert getattr(client, handler_name) is getattr(client._handler_manager, handler_name) - setattr(client, handler_name, handler) - assert getattr(client, handler_name) is handler - assert getattr(client, handler_name) is getattr(client._handler_manager, handler_name) - - @pytest.mark.it( - "Implicitly enables the 
corresponding feature if not already enabled, when a handler value is set" - ) - def test_enables_feature_only_if_not_already_enabled( - self, mocker, client, handler, handler_name, feature_name, mqtt_pipeline - ): - # Feature will appear disabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = False - # Set handler - setattr(client, handler_name, handler) - # Feature was enabled - assert mqtt_pipeline.enable_feature.call_count == 1 - assert mqtt_pipeline.enable_feature.call_args[0][0] == feature_name - - mqtt_pipeline.enable_feature.reset_mock() - - # Feature will appear already enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = True - # Set handler - setattr(client, handler_name, handler) - # Feature was not enabled again - assert mqtt_pipeline.enable_feature.call_count == 0 - - @pytest.mark.it( - "Implicitly disables the corresponding feature if not already disabled, when handler value is set back to None" - ) - def test_disables_feature_only_if_not_already_disabled( - self, mocker, client, handler_name, feature_name, mqtt_pipeline - ): - # Feature will appear enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = True - # Set handler to None - setattr(client, handler_name, None) - # Feature was disabled - assert mqtt_pipeline.disable_feature.call_count == 1 - assert mqtt_pipeline.disable_feature.call_args[0][0] == feature_name - - mqtt_pipeline.disable_feature.reset_mock() - - # Feature will appear already disabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = False - # Set handler to None - setattr(client, handler_name, None) - # Feature was not disabled again - assert mqtt_pipeline.disable_feature.call_count == 0 - - @pytest.mark.it( - "Locks the client to Handler Receive Mode if the receive mode has not yet been set" - ) - def test_receive_mode_not_set(self, client, handler, handler_name): - assert client._receive_type is RECEIVE_TYPE_NONE_SET - setattr(client, handler_name, handler) - assert 
client._receive_type is RECEIVE_TYPE_HANDLER - - @pytest.mark.it( - "Does not modify the client receive mode if it has already been set to Handler Receive Mode" - ) - def test_receive_mode_set_handler(self, client, handler, handler_name): - client._receive_type = RECEIVE_TYPE_HANDLER - setattr(client, handler_name, handler) - assert client._receive_type is RECEIVE_TYPE_HANDLER - - @pytest.mark.it( - "Raises a ClientError and does nothing else if the client receive mode has already been set to API Receive Mode" - ) - def test_receive_mode_set_api(self, client, handler, handler_name, mqtt_pipeline): - client._receive_type = RECEIVE_TYPE_API - # Error was raised - with pytest.raises(client_exceptions.ClientError): - setattr(client, handler_name, handler) - # Feature was not enabled - assert mqtt_pipeline.enable_feature.call_count == 0 - - -# NOTE: If more properties are added, this class should become a general purpose properties test class -class SharedIoTHubClientPROPERTYConnectedTests(object): - @pytest.mark.it("Cannot be changed") - def test_read_only(self, client): - with pytest.raises(AttributeError): - client.connected = not client.connected - - @pytest.mark.it("Reflects the value of the root stage property of the same name") - def test_reflects_pipeline_property(self, client, mqtt_pipeline): - mqtt_pipeline.connected = True - assert client.connected - mqtt_pipeline.connected = False - assert not client.connected - - -class SharedIoTHubClientOCCURRENCEConnectTests(object): - @pytest.mark.it( - "Adds a CONNECTION_STATE_CHANGE ClientEvent to the ClientEvent Inbox if the HandlerManager is currently handling ClientEvents" - ) - def test_handler_manager_handling_events(self, client, mocker): - # NOTE: It's hard to mock a read-only property (.handling_client_events), so we're breaking - # the rule about black-boxing other modules to simulate what we want. Sorry. 
- client._handler_manager._client_event_runner = mocker.MagicMock() # fake thread - assert client._handler_manager.handling_client_events is True - client_event_inbox = client._inbox_manager.get_client_event_inbox() - inbox_put_spy = mocker.spy(client_event_inbox, "put") - assert client_event_inbox.empty() - - client._on_connected() - - # ClientEvent was added - assert not client_event_inbox.empty() - assert inbox_put_spy.call_count == 1 - event = inbox_put_spy.call_args[0][0] - assert isinstance(event, client_event.ClientEvent) - assert event.name == client_event.CONNECTION_STATE_CHANGE - assert event.args_for_user == () - - @pytest.mark.it( - "Does not add any ClientEvents to the ClientEvent Inbox if the HandlerManager is not currently handling ClientEvents" - ) - def test_handler_manager_not_handling_events(self, client): - assert client._handler_manager.handling_client_events is False - client_event_inbox = client._inbox_manager.get_client_event_inbox() - assert client_event_inbox.empty() - - client._on_connected() - - # Inbox is still empty - assert client_event_inbox.empty() - - @pytest.mark.it("Ensures that the HandlerManager is running") - @pytest.mark.parametrize( - "handling_client_events", - [True, False], - ids=["Manager Handling ClientEvents", "Manager Not Handling ClientEvents"], - ) - def test_ensure_handler_manager_running_on_connect( - self, client, mocker, handling_client_events - ): - if handling_client_events: - # NOTE: It's hard to mock a read-only property (.handling_client_events), so we're breaking - # the rule about black-boxing other modules to simulate what we want. Sorry. 
- client._handler_manager._client_event_runner = mocker.MagicMock() # fake thread - assert client._handler_manager.handling_client_events is handling_client_events - ensure_running_spy = mocker.spy(client._handler_manager, "ensure_running") - client._on_connected() - assert ensure_running_spy.call_count == 1 - - -class SharedIoTHubClientOCCURRENCEDisconnectTests(object): - @pytest.mark.it( - "Adds a CONNECTION_STATE_CHANGE ClientEvent to the ClientEvent Inbox if the HandlerManager is currently handling ClientEvents" - ) - def test_handler_manager_handling_event(self, client, mocker): - # NOTE: It's hard to mock a read-only property (.handling_client_events), so we're breaking - # the rule about black-boxing other modules to simulate what we want. Sorry. - client._handler_manager._client_event_runner = mocker.MagicMock() # fake thread - assert client._handler_manager.handling_client_events is True - client_event_inbox = client._inbox_manager.get_client_event_inbox() - inbox_put_spy = mocker.spy(client_event_inbox, "put") - assert client_event_inbox.empty() - - client._on_disconnected() - - assert not client_event_inbox.empty() - assert inbox_put_spy.call_count == 1 - event = inbox_put_spy.call_args[0][0] - assert isinstance(event, client_event.ClientEvent) - assert event.name == client_event.CONNECTION_STATE_CHANGE - assert event.args_for_user == () - - @pytest.mark.it( - "Does not add any ClientEvents to the ClientEvent Inbox if the HandlerManager is not currently handling ClientEvents" - ) - def test_handler_manager_not_handling_events(self, client): - assert client._handler_manager.handling_client_events is False - client_event_inbox = client._inbox_manager.get_client_event_inbox() - assert client_event_inbox.empty() - - client._on_disconnected() - - # Inbox is still empty - assert client_event_inbox.empty() - - @pytest.mark.it("Clears all pending MethodRequests") - @pytest.mark.parametrize( - "handling_client_events", - [True, False], - ids=["Manager Handling 
ClientEvents", "Manager Not Handling ClientEvents"], - ) - def test_state_change_handler_clears_method_request_inboxes_on_disconnect( - self, client, mocker, handling_client_events - ): - if handling_client_events: - # NOTE: It's hard to mock a read-only property (.handling_client_events), so we're breaking - # the rule about black-boxing other modules to simulate what we want. Sorry. - client._handler_manager._client_event_runner = mocker.MagicMock() # fake thread - assert client._handler_manager.handling_client_events is handling_client_events - clear_method_request_spy = mocker.spy(client._inbox_manager, "clear_all_method_requests") - client._on_disconnected() - assert clear_method_request_spy.call_count == 1 - - -class SharedIoTHubClientOCCURRENCENewSastokenRequired(object): - @pytest.mark.it( - "Adds a NEW_SASTOKEN_REQUIRED ClientEvent to the ClientEvent Inbox if the HandlerManager is currently handling ClientEvents" - ) - def test_handler_manager_handling_events(self, client, mocker): - # NOTE: It's hard to mock a read-only property (.handling_client_events), so we're breaking - # the rule about black-boxing other modules to simulate what we want. Sorry. 
- client._handler_manager._client_event_runner = mocker.MagicMock() # fake thread - assert client._handler_manager.handling_client_events is True - client_event_inbox = client._inbox_manager.get_client_event_inbox() - inbox_put_spy = mocker.spy(client_event_inbox, "put") - assert client_event_inbox.empty() - - client._on_new_sastoken_required() - - assert not client_event_inbox.empty() - assert inbox_put_spy.call_count == 1 - event = inbox_put_spy.call_args[0][0] - assert isinstance(event, client_event.ClientEvent) - assert event.name == client_event.NEW_SASTOKEN_REQUIRED - assert event.args_for_user == () - - @pytest.mark.it( - "Does not add any ClientEvents to the ClientEvent Inbox if the HandlerManager is not currently handling ClientEvents" - ) - def test_handler_manager_not_handling_events(self, client): - assert client._handler_manager.handling_client_events is False - client_event_inbox = client._inbox_manager.get_client_event_inbox() - assert client_event_inbox.empty() - - client._on_new_sastoken_required() - - # Inbox still empty - assert client_event_inbox.empty() - - -class SharedIoTHubClientOCCURRENCEBackgroundException(object): - @pytest.mark.it("Sends the exception to the handle_exceptions module") - @pytest.mark.parametrize( - "handling_client_events", - [True, False], - ids=["Manager Handling ClientEvents", "Manager Not Handling ClientEvents"], - ) - def test_handle_exceptions_module( - self, client, mocker, arbitrary_exception, handling_client_events - ): - if handling_client_events: - # NOTE: It's hard to mock a read-only property (.handling_client_events), so we're breaking - # the rule about black-boxing other modules to simulate what we want. Sorry. 
- client._handler_manager._client_event_runner = mocker.MagicMock() # fake thread - assert client._handler_manager.handling_client_events is handling_client_events - background_exc_spy = mocker.spy(handle_exceptions, "handle_background_exception") - - client._on_background_exception(arbitrary_exception) - - assert background_exc_spy.call_count == 1 - assert background_exc_spy.call_args == mocker.call(arbitrary_exception) - - @pytest.mark.it( - "Adds a BACKGROUND_EXCEPTION ClientEvent (containing the exception) to the ClientEvent Inbox if the HandlerManager is currently handling ClientEvents" - ) - def test_handler_manager_handling_events(self, client, mocker, arbitrary_exception): - # NOTE: It's hard to mock a read-only property (.handling_client_events), so we're breaking - # the rule about black-boxing other modules to simulate what we want. Sorry. - client._handler_manager._client_event_runner = mocker.MagicMock() # fake thread - assert client._handler_manager.handling_client_events is True - client_event_inbox = client._inbox_manager.get_client_event_inbox() - inbox_put_spy = mocker.spy(client_event_inbox, "put") - assert client_event_inbox.empty() - - client._on_background_exception(arbitrary_exception) - - assert not client_event_inbox.empty() - assert inbox_put_spy.call_count == 1 - event = inbox_put_spy.call_args[0][0] - assert isinstance(event, client_event.ClientEvent) - assert event.name == client_event.BACKGROUND_EXCEPTION - assert event.args_for_user == (arbitrary_exception,) - - @pytest.mark.it( - "Does not add any ClientEvents to the ClientEvent Inbox if the HandlerManager is not currently handling ClientEvents" - ) - def test_handler_manager_not_handling_events(self, client, arbitrary_exception): - assert client._handler_manager.handling_client_events is False - client_event_inbox = client._inbox_manager.get_client_event_inbox() - assert client_event_inbox.empty() - - client._on_background_exception(arbitrary_exception) - - # Inbox is still empty - 
assert client_event_inbox.empty() - - -############################## -# SHARED DEVICE CLIENT TESTS # -############################## - - -@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") -class SharedIoTHubDeviceClientCreateFromSastokenTests( - SharedIoTHubClientCreateMethodUserOptionTests -): - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_sastoken - - @pytest.fixture - def create_method_args(self, sas_token_string): - """Provides the specific create method args for use in universal tests""" - return [sas_token_string] - - @pytest.mark.it( - "Creates a NonRenewableSasToken from the SAS token string provided in parameters" - ) - def test_sastoken(self, mocker, client_class, sas_token_string): - real_sastoken = st.NonRenewableSasToken(sas_token_string) - sastoken_mock = mocker.patch.object(st, "NonRenewableSasToken") - sastoken_mock.return_value = real_sastoken - - client_class.create_from_sastoken(sastoken=sas_token_string) - - # NonRenewableSasToken created from sastoken string - assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call(sas_token_string) - - @pytest.mark.it( - "Creates MQTT and HTTP pipelines with an IoTHubPipelineConfig object containing the SasToken and values extracted from the SasToken" - ) - def test_pipeline_config( - self, - mocker, - client_class, - sas_token_string, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - real_sastoken = st.NonRenewableSasToken(sas_token_string) - sastoken_mock = mocker.patch.object(st, "NonRenewableSasToken") - sastoken_mock.return_value = real_sastoken - client_class.create_from_sastoken(sas_token_string) - - token_uri_pieces = real_sastoken.resource_uri.split("/") - expected_hostname = token_uri_pieces[0] - expected_device_id = token_uri_pieces[2] - - # Verify pipelines created with an IoTHubPipelineConfig - assert 
mock_mqtt_pipeline_init.call_count == 1 - assert mock_http_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args[0][0] is mock_http_pipeline_init.call_args[0][0] - assert isinstance(mock_mqtt_pipeline_init.call_args[0][0], IoTHubPipelineConfig) - - # Verify the IoTHubPipelineConfig is constructed as expected - config = mock_mqtt_pipeline_init.call_args[0][0] - assert config.device_id == expected_device_id - assert config.module_id is None - assert config.hostname == expected_hostname - assert config.gateway_hostname is None - assert config.sastoken is sastoken_mock.return_value - assert config.blob_upload is True - assert config.method_invoke is False - - @pytest.mark.it( - "Returns an instance of an IoTHubDeviceClient using the created MQTT and HTTP pipelines" - ) - def test_client_returned( - self, - mocker, - client_class, - sas_token_string, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - client = client_class.create_from_sastoken(sastoken=sas_token_string) - assert isinstance(client, client_class) - assert client._mqtt_pipeline is mock_mqtt_pipeline_init.return_value - assert client._http_pipeline is mock_http_pipeline_init.return_value - - @pytest.mark.it("Raises ValueError if NonRenewableSasToken creation results in failure") - def test_raises_value_error_on_sastoken_failure(self, sas_token_string, mocker, client_class): - # NOTE: specific inputs that could cause this are tested in the sastoken test module - sastoken_mock = mocker.patch.object(st, "NonRenewableSasToken") - token_err = st.SasTokenError("Some SasToken failure") - sastoken_mock.side_effect = token_err - - with pytest.raises(ValueError) as e_info: - client_class.create_from_sastoken(sastoken=sas_token_string) - assert e_info.value.__cause__ is token_err - - @pytest.mark.it("Raises ValueError if the provided SAS token string has an invalid URI") - @pytest.mark.parametrize( - "invalid_token_uri", - [ - pytest.param("some.hostname/devices", id="Too short"), - 
pytest.param("some.hostname/devices/my_device/somethingElse", id="Too long"), - pytest.param( - "some.hostname/not-devices/device_id", id="Incorrectly formatted device notation" - ), - pytest.param("some.hostname/devices/my_device/modules/my_module", id="Module URI"), - ], - ) - def test_raises_value_error_invalid_uri(self, mocker, client_class, invalid_token_uri): - sastoken_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( - resource=urllib.parse.quote(invalid_token_uri, safe=""), - signature=urllib.parse.quote("ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=", safe=""), - expiry=int(time.time() + 3600), - ) - - with pytest.raises(ValueError): - client_class.create_from_sastoken(sastoken=sastoken_str) - - @pytest.mark.it("Raises ValueError if the provided SAS token string has already expired") - def test_expired_token(self, mocker, client_class): - sastoken_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( - resource=urllib.parse.quote("some.hostname/devices/my_device", safe=""), - signature=urllib.parse.quote("ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=", safe=""), - expiry=int(time.time() - 3600), # expired - ) - - with pytest.raises(ValueError): - client_class.create_from_sastoken(sastoken=sastoken_str) - - @pytest.mark.it("Raises a TypeError if the 'sastoken_ttl' kwarg is supplied by the user") - def test_sastoken_ttl(self, client_class, sas_token_string): - with pytest.raises(TypeError): - client_class.create_from_sastoken(sastoken=sas_token_string, sastoken_ttl=1000) - - -@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") -class SharedIoTHubDeviceClientCreateFromSymmetricKeyTests( - SharedIoTHubClientCreateMethodUserOptionTests -): - hostname = "durmstranginstitute.farend" - device_id = "MySnitch" - symmetric_key = "Zm9vYmFy" - - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return 
client_class.create_from_symmetric_key - - @pytest.fixture - def create_method_args(self): - """Provides the specific create method args for use in universal tests""" - return [self.symmetric_key, self.hostname, self.device_id] - - @pytest.mark.it( - "Creates a SasToken that uses a SymmetricKeySigningMechanism, from the values provided in parameters" - ) - def test_sastoken(self, mocker, client_class): - sksm_mock = mocker.patch.object(auth, "SymmetricKeySigningMechanism") - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - expected_uri = "{hostname}/devices/{device_id}".format( - hostname=self.hostname, device_id=self.device_id - ) - - custom_ttl = 1000 - client_class.create_from_symmetric_key( - symmetric_key=self.symmetric_key, - hostname=self.hostname, - device_id=self.device_id, - sastoken_ttl=custom_ttl, - ) - - # SymmetricKeySigningMechanism created using the provided symmetric key - assert sksm_mock.call_count == 1 - assert sksm_mock.call_args == mocker.call(key=self.symmetric_key) - - # SasToken created with the SymmetricKeySigningMechanism, the expected URI, and the custom ttl - assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call( - expected_uri, sksm_mock.return_value, ttl=custom_ttl - ) - - @pytest.mark.it( - "Uses 3600 seconds (1 hour) as the default SasToken TTL if no custom TTL is provided" - ) - def test_sastoken_default(self, mocker, client_class): - sksm_mock = mocker.patch.object(auth, "SymmetricKeySigningMechanism") - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - expected_uri = "{hostname}/devices/{device_id}".format( - hostname=self.hostname, device_id=self.device_id - ) - - client_class.create_from_symmetric_key( - symmetric_key=self.symmetric_key, hostname=self.hostname, device_id=self.device_id - ) - - # SymmetricKeySigningMechanism created using the provided symmetric key - assert sksm_mock.call_count == 1 - assert sksm_mock.call_args == mocker.call(key=self.symmetric_key) - - # 
SasToken created with the SymmetricKeySigningMechanism, the expected URI, and the default ttl - assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call( - expected_uri, sksm_mock.return_value, ttl=3600 - ) - - @pytest.mark.it( - "Creates MQTT and HTTP pipelines with an IoTHubPipelineConfig object containing the SasToken and values provided in parameters" - ) - def test_pipeline_config( - self, mocker, client_class, mock_mqtt_pipeline_init, mock_http_pipeline_init - ): - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - - client_class.create_from_symmetric_key( - symmetric_key=self.symmetric_key, hostname=self.hostname, device_id=self.device_id - ) - - # Verify pipelines created with an IoTHubPipelineConfig - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_http_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args[0][0] is mock_http_pipeline_init.call_args[0][0] - assert isinstance(mock_mqtt_pipeline_init.call_args[0][0], IoTHubPipelineConfig) - - # Verify the IoTHubPipelineConfig is constructed as expected - config = mock_mqtt_pipeline_init.call_args[0][0] - assert config.device_id == self.device_id - assert config.hostname == self.hostname - assert config.gateway_hostname is None - assert config.sastoken is sastoken_mock.return_value - assert config.blob_upload is True - assert config.method_invoke is False - - @pytest.mark.it( - "Returns an instance of an IoTHubDeviceClient using the created MQTT and HTTP pipelines" - ) - def test_client_returned( - self, mocker, client_class, mock_mqtt_pipeline_init, mock_http_pipeline_init - ): - client = client_class.create_from_symmetric_key( - symmetric_key=self.symmetric_key, hostname=self.hostname, device_id=self.device_id - ) - assert isinstance(client, client_class) - assert client._mqtt_pipeline is mock_mqtt_pipeline_init.return_value - assert client._http_pipeline is mock_http_pipeline_init.return_value - - @pytest.mark.it("Raises ValueError if a 
SasToken creation results in failure") - def test_raises_value_error_on_sastoken_failure(self, mocker, client_class): - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - token_err = st.SasTokenError("Some SasToken failure") - sastoken_mock.side_effect = token_err - - with pytest.raises(ValueError) as e_info: - client_class.create_from_symmetric_key( - symmetric_key=self.symmetric_key, hostname=self.hostname, device_id=self.device_id - ) - assert e_info.value.__cause__ is token_err - - -@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") -class SharedIoTHubDeviceClientCreateFromX509CertificateTests( - SharedIoTHubClientCreateMethodUserOptionTests -): - hostname = "durmstranginstitute.farend" - device_id = "MySnitch" - - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_x509_certificate - - @pytest.fixture - def create_method_args(self, x509): - """Provides the specific create method args for use in universal tests""" - return [x509, self.hostname, self.device_id] - - @pytest.mark.it( - "Creates MQTT and HTTP pipelines with an IoTHubPipelineConfig object containing the X509 and other values provided in parameters" - ) - def test_pipeline_config( - self, mocker, client_class, x509, mock_mqtt_pipeline_init, mock_http_pipeline_init - ): - client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id - ) - - # Verify pipelines created with an IoTHubPipelineConfig - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_http_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args[0][0] == mock_http_pipeline_init.call_args[0][0] - assert isinstance(mock_mqtt_pipeline_init.call_args[0][0], IoTHubPipelineConfig) - - # Verify the IoTHubPipelineConfig is constructed as expected - config = mock_mqtt_pipeline_init.call_args[0][0] - assert config.device_id == 
self.device_id - assert config.hostname == self.hostname - assert config.gateway_hostname is None - assert config.x509 is x509 - assert config.blob_upload is True - assert config.method_invoke is False - - @pytest.mark.it( - "Returns an instance of an IoTHubDeviceclient using the created MQTT and HTTP pipelines" - ) - def test_client_returned( - self, mocker, client_class, x509, mock_mqtt_pipeline_init, mock_http_pipeline_init - ): - client = client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id - ) - assert isinstance(client, client_class) - assert client._mqtt_pipeline is mock_mqtt_pipeline_init.return_value - assert client._http_pipeline is mock_http_pipeline_init.return_value - - @pytest.mark.it("Raises a TypeError if the 'sastoken_ttl' kwarg is supplied by the user") - def test_sastoken_ttl(self, client_class, x509): - with pytest.raises(TypeError): - client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id, sastoken_ttl=1000 - ) - - -############################## -# SHARED MODULE CLIENT TESTS # -############################## - - -@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") -class SharedIoTHubModuleClientCreateFromSastokenTests( - SharedIoTHubClientCreateMethodUserOptionTests -): - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_sastoken - - @pytest.fixture - def create_method_args(self, sas_token_string): - """Provides the specific create method args for use in universal tests""" - return [sas_token_string] - - @pytest.mark.it( - "Creates a NonRenewableSasToken from the SAS token string provided in parameters" - ) - def test_sastoken(self, mocker, client_class, sas_token_string): - real_sastoken = st.NonRenewableSasToken(sas_token_string) - sastoken_mock = mocker.patch.object(st, "NonRenewableSasToken") - 
sastoken_mock.return_value = real_sastoken - - client_class.create_from_sastoken(sastoken=sas_token_string) - - # NonRenewableSasToken created from sastoken string - assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call(sas_token_string) - - @pytest.mark.it( - "Creates MQTT and HTTP pipelines with an IoTHubPipelineConfig object containing the SasToken and values extracted from the SasToken" - ) - def test_pipeline_config( - self, - mocker, - client_class, - sas_token_string, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - real_sastoken = st.NonRenewableSasToken(sas_token_string) - sastoken_mock = mocker.patch.object(st, "NonRenewableSasToken") - sastoken_mock.return_value = real_sastoken - client_class.create_from_sastoken(sastoken=sas_token_string) - - token_uri_pieces = real_sastoken.resource_uri.split("/") - expected_hostname = token_uri_pieces[0] - expected_device_id = token_uri_pieces[2] - expected_module_id = token_uri_pieces[4] - - # Verify pipelines created with an IoTHubPipelineConfig - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_http_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args[0][0] is mock_http_pipeline_init.call_args[0][0] - assert isinstance(mock_mqtt_pipeline_init.call_args[0][0], IoTHubPipelineConfig) - - # Verify the IoTHubPipelineConfig is constructed as expected - config = mock_mqtt_pipeline_init.call_args[0][0] - assert config.device_id == expected_device_id - assert config.module_id == expected_module_id - assert config.hostname == expected_hostname - assert config.gateway_hostname is None - assert config.sastoken is sastoken_mock.return_value - assert config.blob_upload is False - assert config.method_invoke is False - - @pytest.mark.it( - "Returns an instance of an IoTHubModuleClient using the created MQTT and HTTP pipelines" - ) - def test_client_returned( - self, - mocker, - client_class, - sas_token_string, - mock_mqtt_pipeline_init, - 
mock_http_pipeline_init, - ): - client = client_class.create_from_sastoken(sastoken=sas_token_string) - assert isinstance(client, client_class) - assert client._mqtt_pipeline is mock_mqtt_pipeline_init.return_value - assert client._http_pipeline is mock_http_pipeline_init.return_value - - @pytest.mark.it("Raises ValueError if NonRenewableSasToken creation results in failure") - def test_raises_value_error_on_sastoken_failure(self, mocker, client_class, sas_token_string): - # NOTE: specific inputs that could cause this are tested in the sastoken test module - sastoken_mock = mocker.patch.object(st, "NonRenewableSasToken") - token_err = st.SasTokenError("Some SasToken failure") - sastoken_mock.side_effect = token_err - - with pytest.raises(ValueError) as e_info: - client_class.create_from_sastoken(sastoken=sas_token_string) - assert e_info.value.__cause__ is token_err - - @pytest.mark.it("Raises ValueError if the provided SAS token string has an invalid URI") - @pytest.mark.parametrize( - "invalid_token_uri", - [ - pytest.param("some.hostname/devices/my_device/modules/", id="Too short"), - pytest.param( - "some.hostname/devices/my_device/modules/my_module/somethingElse", id="Too long" - ), - pytest.param( - "some.hostname/not-devices/device_id/modules/module_id", - id="Incorrectly formatted device notation", - ), - pytest.param( - "some.hostname/devices/device_id/not-modules/module_id", - id="Incorrectly formatted module notation", - ), - pytest.param("some.hostname/devices/my_device/", id="Device URI"), - ], - ) - def test_raises_value_error_invalid_uri(self, mocker, client_class, invalid_token_uri): - sastoken_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( - resource=urllib.parse.quote(invalid_token_uri, safe=""), - signature=urllib.parse.quote("ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=", safe=""), - expiry=int(time.time() + 3600), - ) - - with pytest.raises(ValueError): - client_class.create_from_sastoken(sastoken=sastoken_str) 
- - @pytest.mark.it("Raises ValueError if the provided SAS token string has already expired") - def test_expired_token(self, mocker, client_class): - sastoken_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( - resource=urllib.parse.quote( - "some.hostname/devices/my_device/modules/my_module", safe="" - ), - signature=urllib.parse.quote("ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=", safe=""), - expiry=int(time.time() - 3600), # expired - ) - - with pytest.raises(ValueError): - client_class.create_from_sastoken(sastoken=sastoken_str) - - @pytest.mark.it("Raises a TypeError if the 'sastoken_ttl' kwarg is supplied by the user") - def test_sastoken_ttl(self, client_class, sas_token_string): - with pytest.raises(TypeError): - client_class.create_from_sastoken(sastoken=sas_token_string, sastoken_ttl=1000) - - -@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") -class SharedIoTHubModuleClientCreateFromX509CertificateTests( - SharedIoTHubClientCreateMethodUserOptionTests -): - hostname = "durmstranginstitute.farend" - device_id = "MySnitch" - module_id = "Charms" - - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_x509_certificate - - @pytest.fixture - def create_method_args(self, x509): - """Provides the specific create method args for use in universal tests""" - return [x509, self.hostname, self.device_id, self.module_id] - - @pytest.mark.it( - "Creates MQTT and HTTP pipelines with an IoTHubPipelineConfig object containing the X509 and other values provided in parameters" - ) - def test_pipeline_config( - self, mocker, client_class, x509, mock_mqtt_pipeline_init, mock_http_pipeline_init - ): - client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id, module_id=self.module_id - ) - - # Verify pipelines created with an IoTHubPipelineConfig - assert 
mock_mqtt_pipeline_init.call_count == 1 - assert mock_http_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args[0][0] == mock_http_pipeline_init.call_args[0][0] - assert isinstance(mock_mqtt_pipeline_init.call_args[0][0], IoTHubPipelineConfig) - - # Verify the IoTHubPipelineConfig is constructed as expected - config = mock_mqtt_pipeline_init.call_args[0][0] - assert config.device_id == self.device_id - assert config.hostname == self.hostname - assert config.gateway_hostname is None - assert config.x509 is x509 - assert config.blob_upload is False - assert config.method_invoke is False - - @pytest.mark.it( - "Returns an instance of an IoTHubDeviceclient using the created MQTT and HTTP pipelines" - ) - def test_client_returned( - self, mocker, client_class, x509, mock_mqtt_pipeline_init, mock_http_pipeline_init - ): - client = client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id, module_id=self.module_id - ) - assert isinstance(client, client_class) - assert client._mqtt_pipeline is mock_mqtt_pipeline_init.return_value - assert client._http_pipeline is mock_http_pipeline_init.return_value - - @pytest.mark.it("Raises a TypeError if the 'sastoken_ttl' kwarg is supplied by the user") - def test_sastoken_ttl(self, client_class, x509): - with pytest.raises(TypeError): - client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id, sastoken_ttl=1000 - ) - - -@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") -class SharedIoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests( - SharedIoTHubClientCreateMethodUserOptionTests -): - """This class inherits the user option tests shared by all create method APIs, and overrides - tests in order to accommodate unique requirements for the .create_from_edge_environment() method. - - Because .create_from_edge_environment() tests are spread across multiple test units - (i.e. 
test classes), these overrides are done in this class, which is then inherited by all - .create_from_edge_environment() test units below. - """ - - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_edge_environment - - @pytest.fixture - def create_method_args(self): - """Provides the specific create method args for use in universal tests""" - return [] - - @pytest.mark.it( - "Raises a TypeError if the 'server_verification_cert' user option parameter is provided" - ) - def test_server_verification_cert_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - """THIS TEST OVERRIDES AN INHERITED TEST""" - # Override to test that server_verification_cert CANNOT be provided in Edge scenarios - - with pytest.raises(TypeError): - client_create_method( - *create_method_args, server_verification_cert="fake_server_verification_cert" - ) - - @pytest.mark.it( - "Raises a TypeError if the 'gateway_hostname' user option parameter is provided" - ) - def test_gateway_hostname_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - """THIS TEST OVERRIDES AN INHERITED TEST""" - # Override to test that gateway_hostname CANNOT be provided in Edge scenarios - - with pytest.raises(TypeError): - client_create_method(*create_method_args, gateway_hostname="my.gateway.device") - - @pytest.mark.it("Sets default user options if none are provided") - def test_default_options( - self, - mocker, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - """THIS TEST OVERRIDES AN INHERITED TEST""" - # Override so that can avoid the check on server_verification_cert being None - # as in Edge scenarios, it 
is not None - - client_create_method(*create_method_args) - - # Both pipelines use the same IoTHubPipelineConfig - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_http_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args[0][0] is mock_http_pipeline_init.call_args[0][0] - config = mock_mqtt_pipeline_init.call_args[0][0] - assert isinstance(config, IoTHubPipelineConfig) - - # Pipeline Config has default options that were not specified - assert config.product_info == "" - assert config.websockets is False - assert config.cipher == "" - assert config.proxy_options is None - assert config.keep_alive == DEFAULT_KEEPALIVE - - -@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") -class SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnvTests( - SharedIoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests -): - @pytest.fixture - def option_test_required_patching(self, mocker, mock_edge_hsm, edge_container_environment): - """THIS FIXTURE OVERRIDES AN INHERITED FIXTURE""" - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - - @pytest.mark.it( - "Creates a SasToken that uses an IoTEdgeHsm, from the values extracted from the Edge environment and the user-provided TTL" - ) - def test_sastoken(self, mocker, client_class, mock_edge_hsm, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - - expected_uri = "{hostname}/devices/{device_id}/modules/{module_id}".format( - hostname=edge_container_environment["IOTEDGE_IOTHUBHOSTNAME"], - device_id=edge_container_environment["IOTEDGE_DEVICEID"], - module_id=edge_container_environment["IOTEDGE_MODULEID"], - ) - - custom_ttl = 1000 - client_class.create_from_edge_environment(sastoken_ttl=custom_ttl) - - # IoTEdgeHsm created using the extracted values - assert mock_edge_hsm.call_count == 1 - assert mock_edge_hsm.call_args == 
mocker.call( - module_id=edge_container_environment["IOTEDGE_MODULEID"], - generation_id=edge_container_environment["IOTEDGE_MODULEGENERATIONID"], - workload_uri=edge_container_environment["IOTEDGE_WORKLOADURI"], - api_version=edge_container_environment["IOTEDGE_APIVERSION"], - ) - - # SasToken created with the IoTEdgeHsm, the expected URI and the custom ttl - assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call( - expected_uri, mock_edge_hsm.return_value, ttl=custom_ttl - ) - - @pytest.mark.it( - "Uses 3600 seconds (1 hour) as the default SasToken TTL if no custom TTL is provided" - ) - def test_sastoken_default( - self, mocker, client_class, mock_edge_hsm, edge_container_environment - ): - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - - expected_uri = "{hostname}/devices/{device_id}/modules/{module_id}".format( - hostname=edge_container_environment["IOTEDGE_IOTHUBHOSTNAME"], - device_id=edge_container_environment["IOTEDGE_DEVICEID"], - module_id=edge_container_environment["IOTEDGE_MODULEID"], - ) - - client_class.create_from_edge_environment() - - # IoTEdgeHsm created using the extracted values - assert mock_edge_hsm.call_count == 1 - assert mock_edge_hsm.call_args == mocker.call( - module_id=edge_container_environment["IOTEDGE_MODULEID"], - generation_id=edge_container_environment["IOTEDGE_MODULEGENERATIONID"], - workload_uri=edge_container_environment["IOTEDGE_WORKLOADURI"], - api_version=edge_container_environment["IOTEDGE_APIVERSION"], - ) - - # SasToken created with the IoTEdgeHsm, the expected URI, and the default ttl - assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call( - expected_uri, mock_edge_hsm.return_value, ttl=3600 - ) - - @pytest.mark.it( - "Uses an IoTEdgeHsm as the SasToken signing mechanism even if any Edge local debug environment variables may also be present" - ) - def test_hybrid_env( - 
self, - mocker, - client_class, - mock_edge_hsm, - edge_container_environment, - edge_local_debug_environment, - ): - hybrid_environment = merge_dicts(edge_container_environment, edge_local_debug_environment) - mocker.patch.dict(os.environ, hybrid_environment, clear=True) - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - mock_sksm = mocker.patch.object(auth, "SymmetricKeySigningMechanism") - - client_class.create_from_edge_environment() - - assert mock_sksm.call_count == 0 # we did NOT use SK signing mechanism - assert mock_edge_hsm.call_count == 1 # instead, we still used edge hsm - assert mock_edge_hsm.call_args == mocker.call( - module_id=edge_container_environment["IOTEDGE_MODULEID"], - generation_id=edge_container_environment["IOTEDGE_MODULEGENERATIONID"], - workload_uri=edge_container_environment["IOTEDGE_WORKLOADURI"], - api_version=edge_container_environment["IOTEDGE_APIVERSION"], - ) - assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call( - mocker.ANY, mock_edge_hsm.return_value, ttl=3600 - ) - - @pytest.mark.it( - "Creates MQTT and HTTP pipelines with an IoTHubPipelineConfig object containing the SasToken and values extracted from the Edge environment" - ) - def test_pipeline_config( - self, - mocker, - client_class, - mock_edge_hsm, - edge_container_environment, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - - client_class.create_from_edge_environment() - - # Verify pipelines created with an IoTHubPipelineConfig - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_http_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args[0][0] is mock_http_pipeline_init.call_args[0][0] - assert isinstance(mock_mqtt_pipeline_init.call_args[0][0], IoTHubPipelineConfig) - - # Verify the IoTHubPipelineConfig is constructed as expected - config = 
mock_mqtt_pipeline_init.call_args[0][0] - assert config.device_id == edge_container_environment["IOTEDGE_DEVICEID"] - assert config.module_id == edge_container_environment["IOTEDGE_MODULEID"] - assert config.hostname == edge_container_environment["IOTEDGE_IOTHUBHOSTNAME"] - assert config.gateway_hostname == edge_container_environment["IOTEDGE_GATEWAYHOSTNAME"] - assert config.sastoken is sastoken_mock.return_value - assert ( - config.server_verification_cert - == mock_edge_hsm.return_value.get_certificate.return_value - ) - assert config.method_invoke is True - assert config.blob_upload is False - - @pytest.mark.it( - "Returns an instance of an IoTHubModuleClient using the created MQTT and HTTP pipelines" - ) - def test_client_returns( - self, - mocker, - client_class, - mock_edge_hsm, - edge_container_environment, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - - client = client_class.create_from_edge_environment() - assert isinstance(client, client_class) - assert client._mqtt_pipeline is mock_mqtt_pipeline_init.return_value - assert client._http_pipeline is mock_http_pipeline_init.return_value - - @pytest.mark.it("Raises OSError if the environment is missing required variables") - @pytest.mark.parametrize( - "missing_env_var", - [ - "IOTEDGE_MODULEID", - "IOTEDGE_DEVICEID", - "IOTEDGE_IOTHUBHOSTNAME", - "IOTEDGE_GATEWAYHOSTNAME", - "IOTEDGE_APIVERSION", - "IOTEDGE_MODULEGENERATIONID", - "IOTEDGE_WORKLOADURI", - ], - ) - def test_bad_environment( - self, mocker, client_class, edge_container_environment, missing_env_var - ): - # Remove a variable from the fixture - del edge_container_environment[missing_env_var] - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - - with pytest.raises(OSError): - client_class.create_from_edge_environment() - - @pytest.mark.it( - "Raises OSError if there is an error retrieving the server verification certificate from Edge with 
the IoTEdgeHsm" - ) - def test_bad_edge_auth(self, mocker, client_class, edge_container_environment, mock_edge_hsm): - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - my_edge_error = edge_hsm.IoTEdgeError() - mock_edge_hsm.return_value.get_certificate.side_effect = my_edge_error - - with pytest.raises(OSError) as e_info: - client_class.create_from_edge_environment() - assert e_info.value.__cause__ is my_edge_error - - @pytest.mark.it("Raises ValueError if a SasToken creation results in failure") - def test_raises_value_error_on_sastoken_failure( - self, mocker, client_class, edge_container_environment, mock_edge_hsm - ): - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - token_err = st.SasTokenError("Some SasToken failure") - sastoken_mock.side_effect = token_err - - with pytest.raises(ValueError) as e_info: - client_class.create_from_edge_environment() - assert e_info.value.__cause__ is token_err - - -@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") -class SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnvTests( - SharedIoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests -): - @pytest.fixture - def option_test_required_patching(self, mocker, mock_open, edge_local_debug_environment): - """THIS FIXTURE OVERRIDES AN INHERITED FIXTURE""" - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - - @pytest.fixture - def mock_open(self, mocker): - return mocker.patch.object(io, "open") - - @pytest.mark.it( - "Creates a SasToken that uses a SymmetricKeySigningMechanism, from the values in the connection string extracted from the Edge local debug environment, as well as the user-provided TTL" - ) - def test_sastoken(self, mocker, client_class, mock_open, edge_local_debug_environment): - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - sksm_mock = 
mocker.patch.object(auth, "SymmetricKeySigningMechanism") - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - cs_obj = cs.ConnectionString(edge_local_debug_environment["EdgeHubConnectionString"]) - expected_uri = "{hostname}/devices/{device_id}/modules/{module_id}".format( - hostname=cs_obj[cs.HOST_NAME], - device_id=cs_obj[cs.DEVICE_ID], - module_id=cs_obj[cs.MODULE_ID], - ) - - custom_ttl = 1000 - client_class.create_from_edge_environment(sastoken_ttl=custom_ttl) - - # SymmetricKeySigningMechanism created using the connection string's Shared Access Key - assert sksm_mock.call_count == 1 - assert sksm_mock.call_args == mocker.call(key=cs_obj[cs.SHARED_ACCESS_KEY]) - - # SasToken created with the SymmetricKeySigningMechanism, the expected URI, and the custom ttl - assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call( - expected_uri, sksm_mock.return_value, ttl=custom_ttl - ) - - @pytest.mark.it( - "Uses 3600 seconds (1 hour) as the default SasToken TTL if no custom TTL is provided" - ) - def test_sastoken_default(self, mocker, client_class, mock_open, edge_local_debug_environment): - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - sksm_mock = mocker.patch.object(auth, "SymmetricKeySigningMechanism") - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - cs_obj = cs.ConnectionString(edge_local_debug_environment["EdgeHubConnectionString"]) - expected_uri = "{hostname}/devices/{device_id}/modules/{module_id}".format( - hostname=cs_obj[cs.HOST_NAME], - device_id=cs_obj[cs.DEVICE_ID], - module_id=cs_obj[cs.MODULE_ID], - ) - - client_class.create_from_edge_environment() - - # SymmetricKeySigningMechanism created using the connection string's Shared Access Key - assert sksm_mock.call_count == 1 - assert sksm_mock.call_args == mocker.call(key=cs_obj[cs.SHARED_ACCESS_KEY]) - - # SasToken created with the SymmetricKeySigningMechanism, the expected URI and default ttl - assert 
sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call( - expected_uri, sksm_mock.return_value, ttl=3600 - ) - - @pytest.mark.it( - "Only uses Edge local debug variables if no Edge container variables are present in the environment" - ) - def test_auth_provider_and_pipeline_hybrid_env( - self, - mocker, - client_class, - edge_container_environment, - edge_local_debug_environment, - mock_open, - mock_edge_hsm, - ): - # This test verifies that the presence of edge container environment variables means the - # code will follow the edge container environment creation path (using the IoTEdgeHsm) - # even if edge local debug variables are present. - hybrid_environment = merge_dicts(edge_container_environment, edge_local_debug_environment) - mocker.patch.dict(os.environ, hybrid_environment, clear=True) - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - sksm_mock = mocker.patch.object(auth, "SymmetricKeySigningMechanism") - - client_class.create_from_edge_environment() - - assert sksm_mock.call_count == 0 # we did NOT use SK signing mechanism - assert mock_edge_hsm.call_count == 1 # instead, we still used edge HSM - assert mock_edge_hsm.call_args == mocker.call( - module_id=edge_container_environment["IOTEDGE_MODULEID"], - generation_id=edge_container_environment["IOTEDGE_MODULEGENERATIONID"], - workload_uri=edge_container_environment["IOTEDGE_WORKLOADURI"], - api_version=edge_container_environment["IOTEDGE_APIVERSION"], - ) - assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call( - mocker.ANY, mock_edge_hsm.return_value, ttl=3600 - ) - - @pytest.mark.it( - "Extracts the server verification certificate from the file indicated by the filepath extracted from the Edge local debug environment" - ) - def test_open_ca_cert(self, mocker, client_class, edge_local_debug_environment, mock_open): - mock_file_handle = mock_open.return_value.__enter__.return_value - mocker.patch.dict(os.environ, 
edge_local_debug_environment, clear=True) - - client_class.create_from_edge_environment() - - assert mock_open.call_count == 1 - assert mock_open.call_args == mocker.call( - edge_local_debug_environment["EdgeModuleCACertificateFile"], mode="r" - ) - assert mock_file_handle.read.call_count == 1 - assert mock_file_handle.read.call_args == mocker.call() - - @pytest.mark.it( - "Creates MQTT and HTTP pipelines with an IoTHubPipelineConfig object containing the SasToken and values extracted from the Edge local debug environment" - ) - def test_pipeline_config( - self, - mocker, - client_class, - mock_open, - edge_local_debug_environment, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - mock_file_handle = mock_open.return_value.__enter__.return_value - ca_cert_file_contents = "some cert" - mock_file_handle.read.return_value = ca_cert_file_contents - - cs_obj = cs.ConnectionString(edge_local_debug_environment["EdgeHubConnectionString"]) - - client_class.create_from_edge_environment() - - # Verify pipelines created with an IoTHubPipelineConfig - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_http_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args[0][0] is mock_http_pipeline_init.call_args[0][0] - assert isinstance(mock_mqtt_pipeline_init.call_args[0][0], IoTHubPipelineConfig) - - # Verify the IoTHubPipelineConfig is constructed as expected - config = mock_mqtt_pipeline_init.call_args[0][0] - assert config.device_id == cs_obj[cs.DEVICE_ID] - assert config.module_id == cs_obj[cs.MODULE_ID] - assert config.hostname == cs_obj[cs.HOST_NAME] - assert config.gateway_hostname == cs_obj[cs.GATEWAY_HOST_NAME] - assert config.sastoken is sastoken_mock.return_value - assert config.server_verification_cert == ca_cert_file_contents - assert config.method_invoke is True - assert config.blob_upload is False 
- - @pytest.mark.it( - "Returns an instance of an IoTHub client using the created MQTT and HTTP pipelines" - ) - def test_client_returned( - self, - mocker, - client_class, - mock_open, - edge_local_debug_environment, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - - client = client_class.create_from_edge_environment() - - assert isinstance(client, client_class) - assert client._mqtt_pipeline is mock_mqtt_pipeline_init.return_value - assert client._http_pipeline is mock_http_pipeline_init.return_value - - @pytest.mark.it("Raises OSError if the environment is missing required variables") - @pytest.mark.parametrize( - "missing_env_var", ["EdgeHubConnectionString", "EdgeModuleCACertificateFile"] - ) - def test_bad_environment( - self, mocker, client_class, edge_local_debug_environment, missing_env_var, mock_open - ): - # Remove a variable from the fixture - del edge_local_debug_environment[missing_env_var] - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - - with pytest.raises(OSError): - client_class.create_from_edge_environment() - - @pytest.mark.it( - "Raises ValueError if the connection string in the EdgeHubConnectionString environment variable is invalid" - ) - @pytest.mark.parametrize( - "bad_cs", - [ - pytest.param("not-a-connection-string", id="Garbage string"), - pytest.param( - "HostName=value.domain.net;DeviceId=my_device;ModuleId=my_module;SharedAccessKey=Invalid", - id="Shared Access Key invalid", - ), - pytest.param( - "HostName=value.domain.net;WrongValue=Invalid;SharedAccessKey=Zm9vYmFy", - id="Contains extraneous data", - ), - pytest.param("HostName=value.domain.net;DeviceId=my_device", id="Incomplete"), - ], - ) - def test_bad_connection_string( - self, mocker, client_class, edge_local_debug_environment, bad_cs, mock_open - ): - edge_local_debug_environment["EdgeHubConnectionString"] = bad_cs - mocker.patch.dict(os.environ, 
edge_local_debug_environment, clear=True) - - with pytest.raises(ValueError): - client_class.create_from_edge_environment() - - @pytest.mark.it( - "Raises FileNotFoundError if the filepath in the EdgeModuleCACertificateFile environment variable is invalid" - ) - def test_bad_filepath(self, mocker, client_class, edge_local_debug_environment, mock_open): - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - my_fnf_error = FileNotFoundError() - mock_open.side_effect = my_fnf_error - with pytest.raises(FileNotFoundError) as e_info: - client_class.create_from_edge_environment() - assert e_info.value is my_fnf_error - - @pytest.mark.it( - "Raises ValueError if the file referenced by the filepath in the EdgeModuleCACertificateFile environment variable cannot be opened" - ) - def test_bad_file_io(self, mocker, client_class, edge_local_debug_environment, mock_open): - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - my_os_error = OSError() - mock_open.side_effect = my_os_error - with pytest.raises(ValueError) as e_info: - client_class.create_from_edge_environment() - assert e_info.value.__cause__ is my_os_error - - @pytest.mark.it("Raises ValueError if a SasToken creation results in failure") - def test_raises_value_error_on_sastoken_failure( - self, mocker, client_class, edge_local_debug_environment, mock_open - ): - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - token_err = st.SasTokenError("Some SasToken failure") - sastoken_mock.side_effect = token_err - - with pytest.raises(ValueError) as e_info: - client_class.create_from_edge_environment() - assert e_info.value.__cause__ is token_err - - -#################### -# HELPER FUNCTIONS # -#################### -def merge_dicts(d1, d2): - d3 = d1.copy() - d3.update(d2) - return d3 diff --git a/tests/unit/iothub/test_client_event.py b/tests/unit/iothub/test_client_event.py deleted file mode 
100644 index 47e55aa62..000000000 --- a/tests/unit/iothub/test_client_event.py +++ /dev/null @@ -1,43 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest -import logging -from azure.iot.device.iothub.client_event import ( - ClientEvent, - CONNECTION_STATE_CHANGE, - NEW_SASTOKEN_REQUIRED, - BACKGROUND_EXCEPTION, -) - -logging.basicConfig(level=logging.DEBUG) - -all_client_events = [CONNECTION_STATE_CHANGE, NEW_SASTOKEN_REQUIRED, BACKGROUND_EXCEPTION] - - -@pytest.mark.describe("ClientEvent") -class TestClientEvent(object): - @pytest.mark.it("Instantiates with the 'name' attribute set to the provided 'name' parameter") - @pytest.mark.parametrize("name", all_client_events) - def test_name(self, name): - event = ClientEvent(name) - assert event.name == name - - @pytest.mark.it( - "Instantiates with the 'args_for_user' attribute set to a variable-length list of all other provided parameters" - ) - @pytest.mark.parametrize( - "user_args", - [ - pytest.param((), id="0 args"), - pytest.param(("1",), id="1 arg"), - pytest.param(("1", "2"), id="2 args"), - pytest.param(("1", "2", "3", "4", "5"), id="5 args"), - ], - ) - def test_args_for_user(self, user_args): - event = ClientEvent("some_event", *user_args) - assert event.args_for_user == user_args diff --git a/tests/unit/iothub/test_inbox_manager.py b/tests/unit/iothub/test_inbox_manager.py deleted file mode 100644 index eaa824ea9..000000000 --- a/tests/unit/iothub/test_inbox_manager.py +++ /dev/null @@ -1,432 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest -import logging -from azure.iot.device.iothub.inbox_manager import InboxManager -from azure.iot.device.iothub.models import MethodRequest -from azure.iot.device.iothub.aio.async_inbox import AsyncClientInbox -from azure.iot.device.iothub.sync_inbox import SyncClientInbox - -logging.basicConfig(level=logging.DEBUG) - - -@pytest.fixture( - params=[AsyncClientInbox, SyncClientInbox], - ids=["Configured with AsyncClientInboxes", "Configured with SyncClientInboxes"], -) -def inbox_type(request): - return request.param - - -@pytest.fixture -def manager(inbox_type): - return InboxManager(inbox_type=inbox_type) - - -@pytest.fixture -def manager_unified_mode(inbox_type): - m = InboxManager(inbox_type=inbox_type) - m.use_unified_msg_mode = True - return m - - -@pytest.fixture -def method_request(): - return MethodRequest(request_id="1", name="some_method", payload="{'key': 'value'}") - - -@pytest.mark.describe("InboxManager") -class TestInboxManager(object): - @pytest.mark.it("Instantiates with 'Unified Message Mode' turned off") - def test_Unified_mode(self, manager): - assert manager.use_unified_msg_mode is False - - @pytest.mark.it("Instantiates with an empty unified message inbox") - def test_instantiates_with_empty_unified_msg_inbox(self, manager): - assert manager.unified_message_inbox.empty() - - @pytest.mark.it("Instantiates with an empty generic method request inbox") - def test_instantiates_with_empty_generic_method_inbox(self, manager): - assert manager.generic_method_request_inbox.empty() - - @pytest.mark.it("Instantiates with an empty twin patch inbox") - def test_instantiates_with_empty_twin_patch_inbox(self, manager): - assert manager.twin_patch_inbox.empty() - - @pytest.mark.it("Instantiates with an empty client event inbox") - def test_instantiates_with_empty_client_event_inbox(self, manager): - assert 
manager.client_event_inbox.empty() - - @pytest.mark.it("Instantiates with an empty C2D inbox") - def test_instantiates_with_empty_c2d_inbox(self, manager): - assert manager.c2d_message_inbox.empty() - - @pytest.mark.it("Instantiates with no input message inboxes") - def test_instantiates_with_no_input_inboxes(self, manager): - assert manager.input_message_inboxes == {} - - @pytest.mark.it("Instantiates with no specific method request inboxes") - def test_instantiates_with_no_specific_method_inboxes(self, manager): - assert manager.named_method_request_inboxes == {} - - -@pytest.mark.describe("InboxManager - .get_c2d_message_inbox()") -class TestInboxManagerGetC2DMessageInbox(object): - @pytest.mark.it("Returns an inbox") - def test_returns_inbox(self, manager, inbox_type): - c2d_inbox = manager.get_c2d_message_inbox() - assert isinstance(c2d_inbox, inbox_type) - - @pytest.mark.it("Returns the same inbox when called multiple times") - def test_called_multiple_times_returns_same_inbox(self, manager): - c2d_inbox_ref1 = manager.get_c2d_message_inbox() - c2d_inbox_ref2 = manager.get_c2d_message_inbox() - assert c2d_inbox_ref1 is c2d_inbox_ref2 - - -@pytest.mark.describe("InboxManager - .get_unified_message_inbox()") -class TestInboxManagerGetUnifiedMessageInbox(object): - @pytest.mark.it("Returns an inbox") - def test_returns_inbox(self, manager, inbox_type): - um_inbox = manager.get_unified_message_inbox() - assert isinstance(um_inbox, inbox_type) - - @pytest.mark.it("Returns the same inbox when called multiple times") - def test_called_multiple_times_returns_same_inbox(self, manager): - um_inbox_ref1 = manager.get_unified_message_inbox() - um_inbox_ref2 = manager.get_unified_message_inbox() - assert um_inbox_ref1 is um_inbox_ref2 - - -@pytest.mark.describe("InboxManager - .get_input_message_inbox()") -class TestInboxManagerGetInputMessageInbox(object): - @pytest.mark.it("Returns an inbox") - def test_get_input_message_inbox_returns_inbox(self, manager, inbox_type): - 
input_name = "some_input" - input_inbox = manager.get_input_message_inbox(input_name) - assert isinstance(input_inbox, inbox_type) - - @pytest.mark.it("Returns the same inbox when called multiple times with the same input name") - def test_get_input_message_inbox_called_multiple_times_with_same_input_name_returns_same_inbox( - self, manager - ): - input_name = "some_input" - input_inbox_ref1 = manager.get_input_message_inbox(input_name) - input_inbox_ref2 = manager.get_input_message_inbox(input_name) - assert input_inbox_ref1 is input_inbox_ref2 - - @pytest.mark.it( - "Returns a different inbox when called multiple times with a different input name" - ) - def test_get_input_message_inbox_called_multiple_times_with_different_input_name_returns_different_inbox( - self, manager - ): - input_inbox1 = manager.get_input_message_inbox("some_input") - input_inbox2 = manager.get_input_message_inbox("some_other_input") - assert input_inbox1 is not input_inbox2 - - @pytest.mark.it( - "Implicitly creates an input message inbox, that persists, when a new input name is provided" - ) - def test_input_message_inboxes_persist_in_manager_after_creation(self, manager): - assert manager.input_message_inboxes == {} # empty dict - no inboxes - input1 = "some_input" - input_inbox1 = manager.get_input_message_inbox(input1) - assert input1 in manager.input_message_inboxes.keys() - assert input_inbox1 in manager.input_message_inboxes.values() - - -@pytest.mark.describe("InboxManager - .get_method_request_inbox()") -class TestInboxManagerGetMethodRequestInbox(object): - @pytest.mark.it("Returns an inbox") - @pytest.mark.parametrize( - "method_name", - [ - pytest.param("some_method", id="Called with a method name"), - pytest.param(None, id="Called with no method name"), - ], - ) - def test_get_method_request_inbox_returns_inbox(self, manager, method_name, inbox_type): - method_request_inbox = manager.get_method_request_inbox(method_name) - assert isinstance(method_request_inbox, inbox_type) - 
- @pytest.mark.it("Returns the same inbox when called multiple times with the same method name") - @pytest.mark.parametrize( - "method_name", - [ - pytest.param("some_method", id="Called with a method name"), - pytest.param(None, id="Called with no method name"), - ], - ) - def test_get_method_request_inbox_called_multiple_times_with_same_method_name_returns_same_inbox( - self, manager, method_name - ): - message_request_inbox1 = manager.get_method_request_inbox(method_name) - message_request_inbox2 = manager.get_method_request_inbox(method_name) - assert message_request_inbox1 is message_request_inbox2 - - @pytest.mark.it( - "Returns a different inbox when called multiple times with a different method name" - ) - def test_get_method_request_inbox_called_multiple_times_with_different_method_name_returns_different_inbox( - self, manager - ): - message_request_inbox1 = manager.get_method_request_inbox("some_method") - message_request_inbox2 = manager.get_method_request_inbox("some_other_method") - message_request_inbox3 = manager.get_method_request_inbox() - assert message_request_inbox1 is not message_request_inbox2 - assert message_request_inbox1 is not message_request_inbox3 - assert message_request_inbox2 is not message_request_inbox3 - - @pytest.mark.it( - "Implicitly creates an method request inbox, that persists, when a new method name is provided" - ) - def test_input_message_inboxes_persist_in_manager_after_creation(self, manager): - assert manager.named_method_request_inboxes == {} # empty dict - no inboxes - method_name = "some_method" - method_inbox = manager.get_method_request_inbox(method_name) - assert method_name in manager.named_method_request_inboxes.keys() - assert method_inbox in manager.named_method_request_inboxes.values() - - -@pytest.mark.describe("InboxManager - .get_twin_patch_inbox()") -class TestInboxManagerGetTwinPatchInbox(object): - @pytest.mark.it("Returns an inbox") - def test_returns_inbox(self, manager, inbox_type): - tp_inbox = 
manager.get_twin_patch_inbox() - assert isinstance(tp_inbox, inbox_type) - - @pytest.mark.it("Returns the same inbox when called multiple times") - def test_called_multiple_times_returns_same_inbox(self, manager): - tp_inbox_ref1 = manager.get_twin_patch_inbox() - tp_inbox_ref2 = manager.get_twin_patch_inbox() - assert tp_inbox_ref1 is tp_inbox_ref2 - - -@pytest.mark.describe("InboxManager - .get_client_event_inbox()") -class TestInboxManagerGetClientEventInbox(object): - @pytest.mark.it("Returns an inbox") - def test_returns_inbox(self, manager, inbox_type): - ce_inbox = manager.get_client_event_inbox() - assert isinstance(ce_inbox, inbox_type) - - @pytest.mark.it("Returns the same inbox when called multiple times") - def test_called_multiple_times_returns_same_inbox(self, manager): - ce_inbox_ref1 = manager.get_client_event_inbox() - ce_inbox_ref2 = manager.get_client_event_inbox() - assert ce_inbox_ref1 is ce_inbox_ref2 - - -@pytest.mark.describe("InboxManager - .clear_all_method_requests()") -class TestInboxManagerClearAllMethodRequests(object): - @pytest.mark.it("Clears the generic method request inbox") - def test_clears_generic_method_request_inbox(self, manager): - generic_method_request_inbox = manager.get_method_request_inbox() - assert generic_method_request_inbox.empty() - manager.route_method_request(MethodRequest("id", "unrecognized_method_name", "payload")) - assert not generic_method_request_inbox.empty() - - manager.clear_all_method_requests() - assert generic_method_request_inbox.empty() - - @pytest.mark.it("Clears all specific method request inboxes") - def test_clear_all_method_requests_clears_named_method_request_inboxes(self, manager): - method_request_inbox1 = manager.get_method_request_inbox("some_method") - method_request_inbox2 = manager.get_method_request_inbox("some_other_method") - assert method_request_inbox1.empty() - assert method_request_inbox2.empty() - manager.route_method_request(MethodRequest("id1", "some_method", "payload")) - 
manager.route_method_request(MethodRequest("id2", "some_other_method", "payload")) - assert not method_request_inbox1.empty() - assert not method_request_inbox2.empty() - - manager.clear_all_method_requests() - assert method_request_inbox1.empty() - assert method_request_inbox2.empty() - - -@pytest.mark.describe("InboxManager - .route_c2d_message() -- Standard Mode") -class TestInboxManagerRouteC2DMessage(object): - @pytest.mark.it("Adds Message to the C2D message inbox") - def test_adds_message_to_c2d_message_inbox(self, manager, message): - c2d_inbox = manager.get_c2d_message_inbox() - assert c2d_inbox.empty() - delivered = manager.route_c2d_message(message) - assert delivered - assert not c2d_inbox.empty() - assert message in c2d_inbox - - @pytest.mark.it("Does NOT add Message to the unified message inbox") - def test_DOES_NOT_add_message_to_unified_inbox(self, manager, message): - um_inbox = manager.get_unified_message_inbox() - assert um_inbox.empty() - delivered = manager.route_c2d_message(message) - assert delivered - assert um_inbox.empty() - - -@pytest.mark.describe("InboxManager - .route_c2d_message() -- Unified Message Mode") -class TestInboxManagerRouteC2DMessageUnified(object): - @pytest.mark.it("Adds Message to the unified message inbox") - def test_adds_message_to_unified_message_inbox(self, manager_unified_mode, message): - manager = manager_unified_mode - um_inbox = manager.get_unified_message_inbox() - assert um_inbox.empty() - delivered = manager.route_c2d_message(message) - assert delivered - assert not um_inbox.empty() - assert message in um_inbox - - @pytest.mark.it("Does NOT add Message to the C2D message inbox") - def test_DOES_NOT_add_message_to_c2d_inbox(self, manager_unified_mode, message): - manager = manager_unified_mode - c2d_inbox = manager.get_c2d_message_inbox() - assert c2d_inbox.empty() - delivered = manager.route_c2d_message(message) - assert delivered - assert c2d_inbox.empty() - - -@pytest.mark.describe("InboxManager - 
.route_input_message() -- Standard Mode") -class TestInboxManagerRouteInputMessage(object): - @pytest.mark.it( - "Adds Message to the input message inbox that corresponds to the input name, if it exists" - ) - def test_adds_message_to_input_message_inbox(self, manager, message): - input_name = "some_input" - message.input_name = input_name - input_inbox = manager.get_input_message_inbox(input_name) - assert input_inbox.empty() - delivered = manager.route_input_message(message) - assert delivered - assert not input_inbox.empty() - assert message in input_inbox - - @pytest.mark.it( - "Drops a Message if the input name does not correspond to an input message inbox" - ) - def test_drops_message_to_unknown_input(self, manager, message): - message.input_name = "not_a_real_input" - delivered = manager.route_input_message(message) - assert not delivered - - @pytest.mark.it("Does NOT add Message to the unified message inbox") - def test_DOES_NOT_add_message_to_unified_inbox(self, manager, message): - message.input_name = "some_input" - manager.get_input_message_inbox( - message.input_name - ) # create a input inbox to be delivered to - um_inbox = manager.get_unified_message_inbox() - assert um_inbox.empty() - delivered = manager.route_input_message(message) - assert delivered - assert um_inbox.empty() - - -@pytest.mark.describe("InboxManager - .route_input_message() -- Unified Message Mode") -class TestInboxManagerRouteInputMessageUnified(object): - @pytest.mark.it("Adds Message to the unified message inbox") - def test_adds_message_to_unified_message_inbox(self, manager_unified_mode, message): - manager = manager_unified_mode - message.input_name = "some_input" - um_inbox = manager.get_unified_message_inbox() - assert um_inbox.empty() - delivered = manager.route_input_message(message) - assert delivered - assert not um_inbox.empty() - assert message in um_inbox - - @pytest.mark.it("Does NOT add Message to a specific input message inbox, even if one exists") - def 
test_DOES_NOT_add_message_to_input_inbox(self, manager_unified_mode, message): - manager = manager_unified_mode - input_name = "some_input" - message.input_name = input_name - input_inbox = manager.get_input_message_inbox(input_name) - assert input_inbox.empty() - delivered = manager.route_input_message(message) - assert delivered - assert input_inbox.empty() - - -@pytest.mark.describe("InboxManager - .route_method_request()") -class TestInboxManagerRouteMethodRequest(object): - @pytest.mark.it( - "Adds MethodRequest to the method request inbox corresponding to the method name, if it exists" - ) - def test_calling_with_known_method_adds_method_to_named_method_inbox( - self, manager, method_request - ): - # Establish an inbox with the corresponding method name - named_method_inbox = manager.get_method_request_inbox(method_request.name) - generic_method_inbox = manager.get_method_request_inbox() - assert named_method_inbox.empty() - assert generic_method_inbox.empty() - - delivered = manager.route_method_request(method_request) - assert delivered - - # Method Request was delivered to the method inbox with the corresponding name - assert not named_method_inbox.empty() - assert method_request in named_method_inbox - - # Method Request was NOT delivered to the generic method inbox - assert generic_method_inbox.empty() - - @pytest.mark.it( - "Adds MethodRequest to the generic method request inbox, if no inbox corresponding to the method name exists" - ) - def test_calling_with_unknown_method_adds_method_to_generic_method_inbox( - self, manager, method_request - ): - # Do NOT get a specific named inbox - just the generic one - generic_method_inbox = manager.get_method_request_inbox() - assert generic_method_inbox.empty() - - delivered = manager.route_method_request(method_request) - assert delivered - - # Method Request was delivered to the generic method inbox since method name was unknown - assert not generic_method_inbox.empty() - assert method_request in 
generic_method_inbox - - @pytest.mark.it( - "Stops adding MethodRequests to the generic method request inbox once an inbox that corresponds to the method name exists" - ) - def test_routes_method_to_generic_method_inbox_until_named_method_inbox_is_created( - self, manager - ): - # Two MethodRequests for the SAME method name - method_name = "some_method" - method_request1 = MethodRequest( - request_id="1", name=method_name, payload="{'key': 'value'}" - ) - method_request2 = MethodRequest( - request_id="2", name=method_name, payload="{'key': 'value'}" - ) - - # Do NOT get a specific named inbox - just the generic one - generic_method_inbox = manager.get_method_request_inbox() - assert generic_method_inbox.empty() - - # Route the first method request - delivered_request1 = manager.route_method_request(method_request1) - assert delivered_request1 - - # Method Request 1 was delivered to the generic method inbox since the method name was unknown - assert method_request1 in generic_method_inbox - - # Get an inbox for the specific method name - named_method_inbox = manager.get_method_request_inbox(method_name) - assert named_method_inbox.empty() - - # Route the second method request - delivered_request2 = manager.route_method_request(method_request2) - assert delivered_request2 - - # Method Request 2 was delivered to its corresponding named inbox since the method name is known - assert method_request2 in named_method_inbox - assert method_request2 not in generic_method_inbox diff --git a/tests/unit/iothub/test_sync_clients.py b/tests/unit/iothub/test_sync_clients.py deleted file mode 100644 index 19e223fa3..000000000 --- a/tests/unit/iothub/test_sync_clients.py +++ /dev/null @@ -1,2743 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import pytest -import logging -import threading -import time -import urllib -from azure.iot.device.iothub import IoTHubDeviceClient, IoTHubModuleClient -from azure.iot.device import exceptions as client_exceptions -from azure.iot.device.common.auth import sastoken as st -from azure.iot.device.iothub.pipeline import constant as pipeline_constant -from azure.iot.device.iothub.pipeline import exceptions as pipeline_exceptions -from azure.iot.device.iothub.pipeline import IoTHubPipelineConfig -from azure.iot.device.iothub.models import Message, MethodRequest -from azure.iot.device.iothub.sync_inbox import SyncClientInbox -from azure.iot.device.iothub.abstract_clients import ( - RECEIVE_TYPE_NONE_SET, - RECEIVE_TYPE_HANDLER, - RECEIVE_TYPE_API, -) -from azure.iot.device import constant as device_constant -from .shared_client_tests import ( - SharedIoTHubClientInstantiationTests, - SharedIoTHubClientPROPERTYHandlerTests, - SharedIoTHubClientPROPERTYReceiverHandlerTests, - SharedIoTHubClientPROPERTYConnectedTests, - SharedIoTHubClientOCCURRENCEConnectTests, - SharedIoTHubClientOCCURRENCEDisconnectTests, - SharedIoTHubClientOCCURRENCENewSastokenRequired, - SharedIoTHubClientOCCURRENCEBackgroundException, - SharedIoTHubClientCreateFromConnectionStringTests, - SharedIoTHubDeviceClientCreateFromSymmetricKeyTests, - SharedIoTHubDeviceClientCreateFromSastokenTests, - SharedIoTHubDeviceClientCreateFromX509CertificateTests, - SharedIoTHubModuleClientCreateFromX509CertificateTests, - SharedIoTHubModuleClientCreateFromSastokenTests, - SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnvTests, - SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnvTests, -) - -logging.basicConfig(level=logging.DEBUG) - - -################## -# INFRASTRUCTURE # -################## -# TODO: now that there are EventedCallbacks, tests should be updated to test their use -# (which is much simpler than 
this infrastructure) -class WaitsForEventCompletion(object): - def add_event_completion_checks(self, mocker, pipeline_function, args=[], kwargs={}): - event_init_mock = mocker.patch.object(threading, "Event") - event_mock = event_init_mock.return_value - - def check_callback_completes_event(): - # Assert exactly one Event was instantiated so we know the following asserts - # are related to the code under test ONLY - assert event_init_mock.call_count == 1 - - # Assert waiting for Event to complete - assert event_mock.wait.call_count == 1 - assert event_mock.set.call_count == 0 - - # Manually trigger callback - cb = pipeline_function.call_args[1]["callback"] - cb(*args, **kwargs) - - # Assert Event is now completed - assert event_mock.set.call_count == 1 - - event_mock.wait.side_effect = check_callback_completes_event - - -########################## -# SHARED CLIENT FIXTURES # -########################## -@pytest.fixture -def handler(): - def _handler_function(arg): - pass - - return _handler_function - - -####################### -# SHARED CLIENT TESTS # -####################### -class SharedClientShutdownTests(WaitsForEventCompletion): - @pytest.mark.it("Performs a client disconnect (and everything that entails)") - def test_calls_disconnect(self, mocker, client): - # We merely check that disconnect is called here. Doing so does several things, which - # are covered by the disconnect tests themselves. 
Those tests will NOT be duplicated here - client.disconnect = mocker.MagicMock() - assert client.disconnect.call_count == 0 - - client.shutdown() - - assert client.disconnect.call_count == 1 - - @pytest.mark.it("Begins a 'shutdown' pipeline operation") - def test_calls_pipeline_shutdown(self, mocker, client, mqtt_pipeline): - # mock out implicit disconnect - client.disconnect = mocker.MagicMock() - - client.shutdown() - assert mqtt_pipeline.shutdown.call_count == 1 - - @pytest.mark.it( - "Waits for the completion of the 'shutdown' pipeline operation before returning" - ) - def test_waits_for_pipeline_op_completion( - self, mocker, client_manual_cb, mqtt_pipeline_manual_cb - ): - self.add_event_completion_checks( - mocker=mocker, pipeline_function=mqtt_pipeline_manual_cb.shutdown - ) - # mock out implicit disconnect - client_manual_cb.disconnect = mocker.MagicMock() - - client_manual_cb.shutdown() - - @pytest.mark.it( - "Raises a client error if the `shutdown` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - # The only other expected errors are unexpected ones. 
- pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - def test_raises_error_on_pipeline_op_error( - self, mocker, client_manual_cb, mqtt_pipeline_manual_cb, pipeline_error, client_error - ): - # mock out implicit disconnect - client_manual_cb.disconnect = mocker.MagicMock() - - my_pipeline_error = pipeline_error() - self.add_event_completion_checks( - mocker=mocker, - pipeline_function=mqtt_pipeline_manual_cb.shutdown, - kwargs={"error": my_pipeline_error}, - ) - with pytest.raises(client_error) as e_info: - client_manual_cb.shutdown() - assert e_info.value.__cause__ is my_pipeline_error - - @pytest.mark.it( - "Stops the client event handlers after the `shutdown` pipeline operation is complete" - ) - def test_stops_client_event_handlers(self, mocker, client, mqtt_pipeline): - # mock out implicit disconnect - client.disconnect = mocker.MagicMock() - # Spy on handler manager stop. Note that while it does get called twice in shutdown, it - # only happens once here because we have mocked disconnect (where first stoppage) occurs - hm_stop_spy = mocker.spy(client._handler_manager, "stop") - - def check_handlers_and_complete(callback): - assert hm_stop_spy.call_count == 0 - callback() - - mqtt_pipeline.shutdown.side_effect = check_handlers_and_complete - - client.shutdown() - - assert hm_stop_spy.call_count == 1 - assert hm_stop_spy.call_args == mocker.call(receiver_handlers_only=False) - - -class SharedClientConnectTests(WaitsForEventCompletion): - @pytest.mark.it("Begins a 'connect' pipeline operation") - def test_calls_pipeline_connect(self, client, mqtt_pipeline): - client.connect() - assert mqtt_pipeline.connect.call_count == 1 - - @pytest.mark.it("Waits for the completion of the 'connect' pipeline operation before returning") - def test_waits_for_pipeline_op_completion( - self, mocker, client_manual_cb, mqtt_pipeline_manual_cb - ): - self.add_event_completion_checks( - mocker=mocker, 
pipeline_function=mqtt_pipeline_manual_cb.connect - ) - client_manual_cb.connect() - - @pytest.mark.it( - "Raises a client error if the `connect` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ConnectionDroppedError, - client_exceptions.ConnectionDroppedError, - id="ConnectionDroppedError->ConnectionDroppedError", - ), - pytest.param( - pipeline_exceptions.ConnectionFailedError, - client_exceptions.ConnectionFailedError, - id="ConnectionFailedError->ConnectionFailedError", - ), - pytest.param( - pipeline_exceptions.UnauthorizedError, - client_exceptions.CredentialError, - id="UnauthorizedError->CredentialError", - ), - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.TlsExchangeAuthError, - client_exceptions.ClientError, - id="TlsExchangeAuthError->ClientError", - ), - pytest.param( - pipeline_exceptions.ProtocolProxyError, - client_exceptions.ClientError, - id="ProtocolProxyError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param( - pipeline_exceptions.OperationTimeout, - client_exceptions.OperationTimeout, - id="OperationTimeout->OperationTimeout", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - def test_raises_error_on_pipeline_op_error( - self, mocker, client_manual_cb, mqtt_pipeline_manual_cb, pipeline_error, client_error - ): - my_pipeline_error = pipeline_error() - self.add_event_completion_checks( - mocker=mocker, - pipeline_function=mqtt_pipeline_manual_cb.connect, - kwargs={"error": my_pipeline_error}, - ) - with pytest.raises(client_error) as e_info: - client_manual_cb.connect() - assert e_info.value.__cause__ is my_pipeline_error - 
- -class SharedClientDisconnectTests(WaitsForEventCompletion): - @pytest.mark.it( - "Runs a 'disconnect' pipeline operation, stops the receiver handlers, then runs a second 'disconnect' pipeline operation" - ) - def test_calls_pipeline_disconnect(self, mocker, client, mqtt_pipeline): - manager_mock = mocker.MagicMock() - client._handler_manager = mocker.MagicMock() - manager_mock.attach_mock(mqtt_pipeline.disconnect, "disconnect") - manager_mock.attach_mock(client._handler_manager.stop, "stop") - - client.disconnect() - assert mqtt_pipeline.disconnect.call_count == 2 - assert client._handler_manager.stop.call_count == 1 - assert manager_mock.mock_calls == [ - mocker.call.disconnect(callback=mocker.ANY), - mocker.call.stop(receiver_handlers_only=True), - mocker.call.disconnect(callback=mocker.ANY), - ] - - @pytest.mark.it( - "Waits for the completion of both 'disconnect' pipeline operations before returning" - ) - def test_waits_for_pipeline_op_completion(self, mocker, client, mqtt_pipeline): - cb_mock1 = mocker.MagicMock() - cb_mock2 = mocker.MagicMock() - mocker.patch("azure.iot.device.iothub.sync_clients.EventedCallback").side_effect = [ - cb_mock1, - cb_mock2, - ] - - client.disconnect() - - # Disconnect called twice - assert mqtt_pipeline.disconnect.call_count == 2 - # Assert callbacks sent to pipeline - assert mqtt_pipeline.disconnect.call_args_list[0][1]["callback"] is cb_mock1 - assert mqtt_pipeline.disconnect.call_args_list[1][1]["callback"] is cb_mock2 - # Assert callback completions were waited upon - assert cb_mock1.wait_for_completion.call_count == 1 - assert cb_mock2.wait_for_completion.call_count == 1 - - @pytest.mark.it( - "Raises a client error if the `disconnect` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - 
pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - def test_raises_error_on_pipeline_op_error( - self, mocker, client_manual_cb, mqtt_pipeline_manual_cb, pipeline_error, client_error - ): - my_pipeline_error = pipeline_error() - self.add_event_completion_checks( - mocker=mocker, - pipeline_function=mqtt_pipeline_manual_cb.disconnect, - kwargs={"error": my_pipeline_error}, - ) - with pytest.raises(client_error) as e_info: - client_manual_cb.disconnect() - assert e_info.value.__cause__ is my_pipeline_error - - -class SharedClientUpdateSasTokenTests(WaitsForEventCompletion): - # NOTE: Classes that inherit from this class must define some additional fixtures not included - # here, which will be specific to a device or module: - # - sas_config: returns an IoTHubPipelineConfiguration configured for Device/Module - # - uri: A uri that matches the uri in the SAS from sas_token_string fixture - # - nonmatching_uri: A uri that does NOT match to the uri in the SAS from sas_token_string - # - invalid_uri: A uri that is invalid (poorly formed, missing data, etc.) 
- - @pytest.fixture - def device_id(self, sas_token_string): - # NOTE: This is kind of unconventional, but this is the easiest way to extract the - # device id from a sastoken string - sastoken = st.NonRenewableSasToken(sas_token_string) - token_uri_pieces = sastoken.resource_uri.split("/") - device_id = token_uri_pieces[2] - return device_id - - @pytest.fixture - def hostname(self, sas_token_string): - # NOTE: This is kind of unconventional, but this is the easiest way to extract the - # hostname from a sastoken string - sastoken = st.NonRenewableSasToken(sas_token_string) - token_uri_pieces = sastoken.resource_uri.split("/") - hostname = token_uri_pieces[0] - return hostname - - @pytest.fixture - def sas_client(self, client_class, mqtt_pipeline, http_pipeline, sas_config): - """Client configured as if using user-provided, non-renewable SAS auth""" - mqtt_pipeline.pipeline_configuration = sas_config - http_pipeline.pipeline_configuration = sas_config - return client_class(mqtt_pipeline, http_pipeline) - - @pytest.fixture - def sas_client_manual_cb( - self, client_class, mqtt_pipeline_manual_cb, http_pipeline_manual_cb, sas_config - ): - mqtt_pipeline_manual_cb.pipeline_configuration = sas_config - http_pipeline_manual_cb.pipeline_configuration = sas_config - return client_class(mqtt_pipeline_manual_cb, http_pipeline_manual_cb) - - @pytest.fixture - def new_sas_token_string(self, uri): - # New SASToken String that matches old device id and hostname - signature = "AvCQCS7uVk8Lxau7rBs/jek4iwENIwLwpEV7NIJySc0=" - new_token_string = "SharedAccessSignature sr={uri}&sig={signature}&se={expiry}".format( - uri=urllib.parse.quote(uri, safe=""), - signature=urllib.parse.quote(signature, safe=""), - expiry=int(time.time()) + 3600, - ) - return new_token_string - - @pytest.mark.it( - "Creates a new NonRenewableSasToken and sets it on the PipelineConfig, if the new SAS Token string matches the existing SAS Token's information" - ) - def test_updates_token_if_match_vals(self, 
sas_client, new_sas_token_string): - - old_sas_token_string = str(sas_client._mqtt_pipeline.pipeline_configuration.sastoken) - - # Update to new token - sas_client.update_sastoken(new_sas_token_string) - - # Sastoken was updated - assert ( - str(sas_client._mqtt_pipeline.pipeline_configuration.sastoken) == new_sas_token_string - ) - assert ( - str(sas_client._mqtt_pipeline.pipeline_configuration.sastoken) != old_sas_token_string - ) - - @pytest.mark.it("Begins a 'reauthorize connection' pipeline operation") - def test_calls_pipeline_reauthorize(self, sas_client, new_sas_token_string, mqtt_pipeline): - sas_client.update_sastoken(new_sas_token_string) - assert mqtt_pipeline.reauthorize_connection.call_count == 1 - - @pytest.mark.it( - "Waits for the completion of the 'reauthorize connection' pipeline operation before returning" - ) - def test_waits_for_pipeline_op_completion( - self, mocker, sas_client_manual_cb, mqtt_pipeline_manual_cb, new_sas_token_string - ): - self.add_event_completion_checks( - mocker=mocker, pipeline_function=mqtt_pipeline_manual_cb.reauthorize_connection - ) - sas_client_manual_cb.update_sastoken(new_sas_token_string) - - @pytest.mark.it( - "Raises a ClientError if the 'reauthorize connection' pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ConnectionDroppedError, - client_exceptions.ConnectionDroppedError, - id="ConnectionDroppedError->ConnectionDroppedError", - ), - pytest.param( - pipeline_exceptions.ConnectionFailedError, - client_exceptions.ConnectionFailedError, - id="ConnectionFailedError->ConnectionFailedError", - ), - pytest.param( - pipeline_exceptions.UnauthorizedError, - client_exceptions.CredentialError, - id="UnauthorizedError->CredentialError", - ), - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - 
pipeline_exceptions.TlsExchangeAuthError, - client_exceptions.ClientError, - id="TlsExchangeAuthError->ClientError", - ), - pytest.param( - pipeline_exceptions.ProtocolProxyError, - client_exceptions.ClientError, - id="ProtocolProxyError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param( - pipeline_exceptions.OperationTimeout, - client_exceptions.OperationTimeout, - id="OperationTimeout->OperationTimeout", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - def test_raises_error_on_pipeline_op_error( - self, - mocker, - sas_client_manual_cb, - mqtt_pipeline_manual_cb, - new_sas_token_string, - client_error, - pipeline_error, - ): - # NOTE: If/When the MQTT pipeline is updated so that the reauthorize op waits for - # reconnection in order to return (currently it just waits for the disconnect), - # there will need to be additional connect-related errors in the parametrization. 
- my_pipeline_error = pipeline_error() - self.add_event_completion_checks( - mocker=mocker, - pipeline_function=mqtt_pipeline_manual_cb.reauthorize_connection, - kwargs={"error": my_pipeline_error}, - ) - with pytest.raises(client_error) as e_info: - sas_client_manual_cb.update_sastoken(new_sas_token_string) - assert e_info.value.__cause__ is my_pipeline_error - - @pytest.mark.it( - "Raises a ClientError if the client was created with an X509 certificate instead of SAS" - ) - def test_created_with_x509(self, mocker, sas_client, new_sas_token_string): - # Modify client to seem as if created with X509 - x509_client = sas_client - x509_client._mqtt_pipeline.pipeline_configuration.sastoken = None - x509_client._mqtt_pipeline.pipeline_configuration.x509 = mocker.MagicMock() - - with pytest.raises(client_exceptions.ClientError): - x509_client.update_sastoken(new_sas_token_string) - - @pytest.mark.it( - "Raises a ClientError if the client was created with a renewable, non-user provided SAS (e.g. 
from connection string, symmetric key, etc.)" - ) - def test_created_with_renewable_sas(self, mocker, uri, sas_client, new_sas_token_string): - # Modify client to seem as if created with renewable SAS - mock_signing_mechanism = mocker.MagicMock() - mock_signing_mechanism.sign.return_value = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=" - renewable_token = st.RenewableSasToken(uri, mock_signing_mechanism) - sas_client._mqtt_pipeline.pipeline_configuration.sastoken = renewable_token - - # Client fails - with pytest.raises(client_exceptions.ClientError): - sas_client.update_sastoken(new_sas_token_string) - - @pytest.mark.it("Raises a ValueError if there is an error creating a new NonRenewableSasToken") - def test_token_error(self, mocker, sas_client, new_sas_token_string): - # NOTE: specific inputs that could cause this are tested in the sastoken test module - sastoken_mock = mocker.patch.object(st.NonRenewableSasToken, "__init__") - token_err = st.SasTokenError("Some SasToken failure") - sastoken_mock.side_effect = token_err - - with pytest.raises(ValueError) as e_info: - sas_client.update_sastoken(new_sas_token_string) - assert e_info.value.__cause__ is token_err - - @pytest.mark.it("Raises ValueError if the provided SAS token string has already expired") - def test_expired_token(self, mocker, uri, sas_client, hostname, device_id): - sastoken_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( - resource=urllib.parse.quote(uri, safe=""), - signature=urllib.parse.quote("ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=", safe=""), - expiry=int(time.time() - 3600), # expired - ) - - with pytest.raises(ValueError): - sas_client.update_sastoken(sastoken_str) - - @pytest.mark.it( - "Raises ValueError if the provided SAS token string does not match the previous SAS details" - ) - def test_nonmatching_uri_in_new_token(self, sas_client, nonmatching_uri): - signature = "AvCQCS7uVk8Lxau7rBs/jek4iwENIwLwpEV7NIJySc0=" - sastoken_str = 
"SharedAccessSignature sr={uri}&sig={signature}&se={expiry}".format( - uri=urllib.parse.quote(nonmatching_uri, safe=""), - signature=urllib.parse.quote(signature), - expiry=int(time.time()) + 3600, - ) - - with pytest.raises(ValueError): - sas_client.update_sastoken(sastoken_str) - - @pytest.mark.it("Raises ValueError if the provided SAS token string has an invalid URI") - def test_raises_value_error_invalid_uri(self, mocker, sas_client, invalid_uri): - sastoken_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( - resource=urllib.parse.quote(invalid_uri, safe=""), - signature=urllib.parse.quote("ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=", safe=""), - expiry=int(time.time() + 3600), - ) - - with pytest.raises(ValueError): - sas_client.update_sastoken(sastoken_str) - - -class SharedClientSendD2CMessageTests(WaitsForEventCompletion): - @pytest.mark.it("Begins a 'send_message' MQTTPipeline operation") - def test_calls_pipeline_send_message(self, client, mqtt_pipeline, message): - client.send_message(message) - assert mqtt_pipeline.send_message.call_count == 1 - assert mqtt_pipeline.send_message.call_args[0][0] is message - - @pytest.mark.it( - "Waits for the completion of the 'send_message' pipeline operation before returning" - ) - def test_waits_for_pipeline_op_completion( - self, mocker, client_manual_cb, mqtt_pipeline_manual_cb, message - ): - self.add_event_completion_checks( - mocker=mocker, pipeline_function=mqtt_pipeline_manual_cb.send_message - ) - client_manual_cb.send_message(message) - - @pytest.mark.it( - "Raises a client error if the `send_message` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ConnectionDroppedError, - client_exceptions.ConnectionDroppedError, - id="ConnectionDroppedError->ConnectionDroppedError", - ), - pytest.param( - pipeline_exceptions.ConnectionFailedError, - 
client_exceptions.ConnectionFailedError, - id="ConnectionFailedError->ConnectionFailedError", - ), - pytest.param( - pipeline_exceptions.NoConnectionError, - client_exceptions.NoConnectionError, - id="NoConnectionError->NoConnectionError", - ), - pytest.param( - pipeline_exceptions.UnauthorizedError, - client_exceptions.CredentialError, - id="UnauthorizedError->CredentialError", - ), - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationTimeout, - client_exceptions.OperationTimeout, - id="OperationTimeout -> OperationTimeout", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - def test_raises_error_on_pipeline_op_error( - self, - mocker, - client_manual_cb, - mqtt_pipeline_manual_cb, - message, - pipeline_error, - client_error, - ): - my_pipeline_error = pipeline_error() - self.add_event_completion_checks( - mocker=mocker, - pipeline_function=mqtt_pipeline_manual_cb.send_message, - kwargs={"error": my_pipeline_error}, - ) - with pytest.raises(client_error) as e_info: - client_manual_cb.send_message(message) - assert e_info.value.__cause__ is my_pipeline_error - - @pytest.mark.it( - "Wraps 'message' input parameter in a Message object if it is not a Message object" - ) - @pytest.mark.parametrize( - "message_input", - [ - pytest.param("message", id="String input"), - pytest.param(222, id="Integer input"), - pytest.param(object(), id="Object input"), - pytest.param(None, id="None input"), - pytest.param([1, "str"], id="List input"), - pytest.param({"a": 2}, id="Dictionary input"), - ], - ) - def test_wraps_data_in_message_and_calls_pipeline_send_message( - self, client, mqtt_pipeline, message_input - ): - client.send_message(message_input) 
- assert mqtt_pipeline.send_message.call_count == 1 - sent_message = mqtt_pipeline.send_message.call_args[0][0] - assert isinstance(sent_message, Message) - assert sent_message.data == message_input - - @pytest.mark.it("Raises error when message data size is greater than 256 KB") - def test_raises_error_when_message_data_greater_than_256(self, client, mqtt_pipeline): - data_input = "serpensortia" * 25600 - message = Message(data_input) - with pytest.raises(ValueError) as e_info: - client.send_message(message) - assert "256 KB" in e_info.value.args[0] - assert mqtt_pipeline.send_message.call_count == 0 - - @pytest.mark.it("Raises error when message size is greater than 256 KB") - def test_raises_error_when_message_size_greater_than_256(self, client, mqtt_pipeline): - data_input = "serpensortia" - message = Message(data_input) - message.custom_properties["spell"] = data_input * 25600 - with pytest.raises(ValueError) as e_info: - client.send_message(message) - assert "256 KB" in e_info.value.args[0] - assert mqtt_pipeline.send_message.call_count == 0 - - @pytest.mark.it("Does not raises error when message data size is equal to 256 KB") - def test_raises_error_when_message_data_equal_to_256(self, client, mqtt_pipeline): - data_input = "a" * 262095 - message = Message(data_input) - # This check was put as message class may undergo the default content type encoding change - # and the above calculation will change. 
- if message.get_size() != device_constant.TELEMETRY_MESSAGE_SIZE_LIMIT: - assert False - - client.send_message(message) - - assert mqtt_pipeline.send_message.call_count == 1 - sent_message = mqtt_pipeline.send_message.call_args[0][0] - assert isinstance(sent_message, Message) - assert sent_message.data == data_input - - -class SharedClientReceiveMethodRequestTests(object): - @pytest.mark.it("Implicitly enables methods feature if not already enabled") - @pytest.mark.parametrize( - "method_name", - [pytest.param(None, id="Generic Method"), pytest.param("method_x", id="Named Method")], - ) - def test_enables_methods_only_if_not_already_enabled( - self, mocker, client, mqtt_pipeline, method_name - ): - mocker.patch.object(SyncClientInbox, "get") # patch this receive_method_request won't block - - # Verify Input Messaging enabled if not enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = ( - False # Method Requests will appear disabled - ) - client.receive_method_request(method_name) - assert mqtt_pipeline.enable_feature.call_count == 1 - assert mqtt_pipeline.enable_feature.call_args[0][0] == pipeline_constant.METHODS - - mqtt_pipeline.enable_feature.reset_mock() - - # Verify Input Messaging not enabled if already enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = ( - True # Input Messages will appear enabled - ) - client.receive_method_request(method_name) - assert mqtt_pipeline.enable_feature.call_count == 0 - - @pytest.mark.it( - "Returns a MethodRequest from the generic method inbox, if available, when called without method name" - ) - def test_called_without_method_name_returns_method_request_from_generic_method_inbox( - self, mocker, client - ): - request = MethodRequest(request_id="1", name="some_method", payload={"key": "value"}) - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - inbox_mock.get.return_value = request - manager_get_inbox_mock = mocker.patch.object( - target=client._inbox_manager, - 
attribute="get_method_request_inbox", - return_value=inbox_mock, - ) - - received_request = client.receive_method_request() - assert manager_get_inbox_mock.call_count == 1 - assert manager_get_inbox_mock.call_args == mocker.call(None) - assert inbox_mock.get.call_count == 1 - assert received_request is received_request - - @pytest.mark.it( - "Returns MethodRequest from the corresponding method inbox, if available, when called with a method name" - ) - def test_called_with_method_name_returns_method_request_from_named_method_inbox( - self, mocker, client - ): - method_name = "some_method" - request = MethodRequest(request_id="1", name=method_name, payload={"key": "value"}) - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - inbox_mock.get.return_value = request - manager_get_inbox_mock = mocker.patch.object( - target=client._inbox_manager, - attribute="get_method_request_inbox", - return_value=inbox_mock, - ) - - received_request = client.receive_method_request(method_name) - assert manager_get_inbox_mock.call_count == 1 - assert manager_get_inbox_mock.call_args == mocker.call(method_name) - assert inbox_mock.get.call_count == 1 - assert received_request is received_request - - @pytest.mark.it("Can be called in various modes") - @pytest.mark.parametrize( - "method_name", - [pytest.param(None, id="Generic Method"), pytest.param("method_x", id="Named Method")], - ) - @pytest.mark.parametrize( - "block,timeout", - [ - pytest.param(True, None, id="Blocking, no timeout"), - pytest.param(True, 10, id="Blocking with timeout"), - pytest.param(False, None, id="Nonblocking"), - ], - ) - def test_receive_method_request_can_be_called_in_mode( - self, mocker, client, block, timeout, method_name - ): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object( - target=client._inbox_manager, - attribute="get_method_request_inbox", - return_value=inbox_mock, - ) - - client.receive_method_request(method_name=method_name, block=block, timeout=timeout) - 
assert inbox_mock.get.call_count == 1 - assert inbox_mock.get.call_args == mocker.call(block=block, timeout=timeout) - - @pytest.mark.it("Defaults to blocking mode with no timeout") - @pytest.mark.parametrize( - "method_name", - [pytest.param(None, id="Generic Method"), pytest.param("method_x", id="Named Method")], - ) - def test_receive_method_request_default_mode(self, mocker, client, method_name): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object( - target=client._inbox_manager, - attribute="get_method_request_inbox", - return_value=inbox_mock, - ) - client.receive_method_request(method_name=method_name) - assert inbox_mock.get.call_count == 1 - assert inbox_mock.get.call_args == mocker.call(block=True, timeout=None) - - @pytest.mark.it("Blocks until a method request is available, in blocking mode") - @pytest.mark.parametrize( - "method_name", - [pytest.param(None, id="Generic Method"), pytest.param("method_x", id="Named Method")], - ) - def test_no_method_request_in_inbox_blocking_mode(self, client, method_name): - request = MethodRequest(request_id="1", name=method_name, payload={"key": "value"}) - - inbox = client._inbox_manager.get_method_request_inbox(method_name) - assert inbox.empty() - - def insert_item_after_delay(): - time.sleep(0.01) - inbox.put(request) - - insertion_thread = threading.Thread(target=insert_item_after_delay) - insertion_thread.start() - - received_request = client.receive_method_request(method_name, block=True) - assert received_request is request - # This proves that the blocking happens because 'received_request' can't be - # 'request' until after a 10 millisecond delay on the insert. But because the - # 'received_request' IS 'request', it means that client.receive_method_request - # did not return until after the delay. 
- - @pytest.mark.it( - "Returns None after a timeout while blocking, in blocking mode with a specified timeout" - ) - @pytest.mark.parametrize( - "method_name", - [pytest.param(None, id="Generic Method"), pytest.param("method_x", id="Named Method")], - ) - def test_times_out_waiting_for_message_blocking_mode(self, client, method_name): - result = client.receive_method_request(method_name, block=True, timeout=0.01) - assert result is None - - @pytest.mark.it("Returns None immediately if there are no messages, in nonblocking mode") - @pytest.mark.parametrize( - "method_name", - [pytest.param(None, id="Generic Method"), pytest.param("method_x", id="Named Method")], - ) - def test_no_message_in_inbox_nonblocking_mode(self, client, method_name): - result = client.receive_method_request(method_name, block=False) - assert result is None - - @pytest.mark.it("Locks the client to API Receive Mode if the receive mode has not yet been set") - @pytest.mark.parametrize( - "method_name", - [pytest.param(None, id="Generic Method"), pytest.param("method_x", id="Named Method")], - ) - @pytest.mark.parametrize( - "block,timeout", - [ - pytest.param(True, None, id="Blocking, no timeout"), - pytest.param(True, 10, id="Blocking with timeout"), - pytest.param(False, None, id="Nonblocking"), - ], - ) - def test_receive_mode_not_set(self, mocker, client, method_name, block, timeout): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object( - client._inbox_manager, "get_method_request_inbox", return_value=inbox_mock - ) - - assert client._receive_type is RECEIVE_TYPE_NONE_SET - client.receive_method_request(method_name=method_name, block=block, timeout=timeout) - assert client._receive_type is RECEIVE_TYPE_API - - @pytest.mark.it( - "Does not modify the client receive mode if it has already been set to API Receive Mode" - ) - @pytest.mark.parametrize( - "method_name", - [pytest.param(None, id="Generic Method"), pytest.param("method_x", id="Named Method")], - ) - 
@pytest.mark.parametrize( - "block,timeout", - [ - pytest.param(True, None, id="Blocking, no timeout"), - pytest.param(True, 10, id="Blocking with timeout"), - pytest.param(False, None, id="Nonblocking"), - ], - ) - def test_receive_mode_set_api(self, mocker, client, method_name, block, timeout): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object( - client._inbox_manager, "get_method_request_inbox", return_value=inbox_mock - ) - - client._receive_type = RECEIVE_TYPE_API - client.receive_method_request(method_name=method_name, block=block, timeout=timeout) - assert client._receive_type is RECEIVE_TYPE_API - - @pytest.mark.it( - "Raises a ClientError and does nothing else if the client receive mode has been set to Handler Receive Mode" - ) - @pytest.mark.parametrize( - "method_name", - [pytest.param(None, id="Generic Method"), pytest.param("method_x", id="Named Method")], - ) - @pytest.mark.parametrize( - "block,timeout", - [ - pytest.param(True, None, id="Blocking, no timeout"), - pytest.param(True, 10, id="Blocking with timeout"), - pytest.param(False, None, id="Nonblocking"), - ], - ) - def test_receive_mode_set_handler( - self, mocker, client, mqtt_pipeline, method_name, block, timeout - ): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object( - client._inbox_manager, "get_method_request_inbox", return_value=inbox_mock - ) - # patch this so we can make sure feature enabled isn't modified - mqtt_pipeline.feature_enabled.__getitem__.return_value = False - - client._receive_type = RECEIVE_TYPE_HANDLER - # Error was raised - with pytest.raises(client_exceptions.ClientError): - client.receive_method_request(method_name=method_name, block=block, timeout=timeout) - # Feature was not enabled - assert mqtt_pipeline.enable_feature.call_count == 0 - # Inbox get was not called - assert inbox_mock.get.call_count == 0 - - -class SharedClientSendMethodResponseTests(WaitsForEventCompletion): - @pytest.mark.it("Begins a 
'send_method_response' pipeline operation") - def test_send_method_response_calls_pipeline(self, client, mqtt_pipeline, method_response): - - client.send_method_response(method_response) - assert mqtt_pipeline.send_method_response.call_count == 1 - assert mqtt_pipeline.send_method_response.call_args[0][0] is method_response - - @pytest.mark.it( - "Waits for the completion of the 'send_method_response' pipeline operation before returning" - ) - def test_waits_for_pipeline_op_completion( - self, mocker, client_manual_cb, mqtt_pipeline_manual_cb, method_response - ): - self.add_event_completion_checks( - mocker=mocker, pipeline_function=mqtt_pipeline_manual_cb.send_method_response - ) - client_manual_cb.send_method_response(method_response) - - @pytest.mark.it( - "Raises a client error if the `send_method_response` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ConnectionDroppedError, - client_exceptions.ConnectionDroppedError, - id="ConnectionDroppedError->ConnectionDroppedError", - ), - pytest.param( - pipeline_exceptions.ConnectionFailedError, - client_exceptions.ConnectionFailedError, - id="ConnectionFailedError->ConnectionFailedError", - ), - pytest.param( - pipeline_exceptions.NoConnectionError, - client_exceptions.NoConnectionError, - id="NoConnectionError->NoConnectionError", - ), - pytest.param( - pipeline_exceptions.UnauthorizedError, - client_exceptions.CredentialError, - id="UnauthorizedError->CredentialError", - ), - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param( - pipeline_exceptions.OperationTimeout, - client_exceptions.OperationTimeout, - id="OperationTimeout -> OperationTimeout", - ), - 
pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - def test_raises_error_on_pipeline_op_error( - self, - mocker, - client_manual_cb, - mqtt_pipeline_manual_cb, - method_response, - pipeline_error, - client_error, - ): - my_pipeline_error = pipeline_error() - self.add_event_completion_checks( - mocker=mocker, - pipeline_function=mqtt_pipeline_manual_cb.send_method_response, - kwargs={"error": my_pipeline_error}, - ) - with pytest.raises(client_error) as e_info: - client_manual_cb.send_method_response(method_response) - assert e_info.value.__cause__ is my_pipeline_error - - -class SharedClientGetTwinTests(WaitsForEventCompletion): - @pytest.fixture - def patch_get_twin_to_return_fake_twin(self, fake_twin, mocker, mqtt_pipeline): - def immediate_callback(callback): - callback(twin=fake_twin) - - mocker.patch.object(mqtt_pipeline, "get_twin", side_effect=immediate_callback) - - @pytest.mark.it("Implicitly enables twin messaging feature if not already enabled") - def test_enables_twin_only_if_not_already_enabled( - self, mocker, client, mqtt_pipeline, patch_get_twin_to_return_fake_twin, fake_twin - ): - # Verify twin enabled if not enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = False # twin will appear disabled - client.get_twin() - assert mqtt_pipeline.enable_feature.call_count == 1 - assert mqtt_pipeline.enable_feature.call_args[0][0] == pipeline_constant.TWIN - - mqtt_pipeline.enable_feature.reset_mock() - - # Verify twin not enabled if already enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = True # twin will appear enabled - client.get_twin() - assert mqtt_pipeline.enable_feature.call_count == 0 - - @pytest.mark.it("Begins a 'get_twin' pipeline operation") - def test_get_twin_calls_pipeline(self, client, mqtt_pipeline): - client.get_twin() - assert mqtt_pipeline.get_twin.call_count == 1 - - @pytest.mark.it( - "Waits for the completion of the 'get_twin' pipeline operation before 
returning" - ) - def test_waits_for_pipeline_op_completion( - self, mocker, client_manual_cb, mqtt_pipeline_manual_cb, fake_twin - ): - self.add_event_completion_checks( - mocker=mocker, - pipeline_function=mqtt_pipeline_manual_cb.get_twin, - kwargs={"twin": fake_twin}, - ) - client_manual_cb.get_twin() - - @pytest.mark.it( - "Raises a client error if the `get_twin` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ConnectionDroppedError, - client_exceptions.ConnectionDroppedError, - id="ConnectionDroppedError->ConnectionDroppedError", - ), - pytest.param( - pipeline_exceptions.ConnectionFailedError, - client_exceptions.ConnectionFailedError, - id="ConnectionFailedError->ConnectionFailedError", - ), - pytest.param( - pipeline_exceptions.NoConnectionError, - client_exceptions.NoConnectionError, - id="NoConnectionError->NoConnectionError", - ), - pytest.param( - pipeline_exceptions.UnauthorizedError, - client_exceptions.CredentialError, - id="UnauthorizedError->CredentialError", - ), - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param( - pipeline_exceptions.OperationTimeout, - client_exceptions.OperationTimeout, - id="OperationTimeout -> OperationTimeout", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - def test_raises_error_on_pipeline_op_error( - self, mocker, client_manual_cb, mqtt_pipeline_manual_cb, pipeline_error, client_error - ): - my_pipeline_error = pipeline_error() - self.add_event_completion_checks( - mocker=mocker, - pipeline_function=mqtt_pipeline_manual_cb.get_twin, - kwargs={"error": my_pipeline_error}, - ) - with pytest.raises(client_error) 
as e_info: - client_manual_cb.get_twin() - assert e_info.value.__cause__ is my_pipeline_error - - @pytest.mark.it("Returns the twin that the pipeline returned") - def test_verifies_twin_returned( - self, mocker, client_manual_cb, mqtt_pipeline_manual_cb, fake_twin - ): - self.add_event_completion_checks( - mocker=mocker, - pipeline_function=mqtt_pipeline_manual_cb.get_twin, - kwargs={"twin": fake_twin}, - ) - returned_twin = client_manual_cb.get_twin() - assert returned_twin == fake_twin - - -class SharedClientPatchTwinReportedPropertiesTests(WaitsForEventCompletion): - @pytest.mark.it("Implicitly enables twin messaging feature if not already enabled") - def test_enables_twin_only_if_not_already_enabled( - self, mocker, client, mqtt_pipeline, twin_patch_reported - ): - # patch this so x_get_twin won't block - def immediate_callback(patch, callback): - callback() - - mocker.patch.object( - mqtt_pipeline, "patch_twin_reported_properties", side_effect=immediate_callback - ) - - # Verify twin enabled if not enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = False # twin will appear disabled - client.patch_twin_reported_properties(twin_patch_reported) - assert mqtt_pipeline.enable_feature.call_count == 1 - assert mqtt_pipeline.enable_feature.call_args[0][0] == pipeline_constant.TWIN - - mqtt_pipeline.enable_feature.reset_mock() - - # Verify twin not enabled if already enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = True # twin will appear enabled - client.patch_twin_reported_properties(twin_patch_reported) - assert mqtt_pipeline.enable_feature.call_count == 0 - - @pytest.mark.it("Begins a 'patch_twin_reported_properties' pipeline operation") - def test_patch_twin_reported_properties_calls_pipeline( - self, client, mqtt_pipeline, twin_patch_reported - ): - client.patch_twin_reported_properties(twin_patch_reported) - assert mqtt_pipeline.patch_twin_reported_properties.call_count == 1 - assert ( - 
mqtt_pipeline.patch_twin_reported_properties.call_args[1]["patch"] - is twin_patch_reported - ) - - @pytest.mark.it( - "Waits for the completion of the 'patch_twin_reported_properties' pipeline operation before returning" - ) - def test_waits_for_pipeline_op_completion( - self, mocker, client_manual_cb, mqtt_pipeline_manual_cb, twin_patch_reported - ): - self.add_event_completion_checks( - mocker=mocker, pipeline_function=mqtt_pipeline_manual_cb.patch_twin_reported_properties - ) - client_manual_cb.patch_twin_reported_properties(twin_patch_reported) - - @pytest.mark.it( - "Raises a client error if the `patch_twin_reported_properties` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ConnectionDroppedError, - client_exceptions.ConnectionDroppedError, - id="ConnectionDroppedError->ConnectionDroppedError", - ), - pytest.param( - pipeline_exceptions.ConnectionFailedError, - client_exceptions.ConnectionFailedError, - id="ConnectionFailedError->ConnectionFailedError", - ), - pytest.param( - pipeline_exceptions.NoConnectionError, - client_exceptions.NoConnectionError, - id="NoConnectionError->NoConnectionError", - ), - pytest.param( - pipeline_exceptions.UnauthorizedError, - client_exceptions.CredentialError, - id="UnauthorizedError->CredentialError", - ), - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param( - pipeline_exceptions.OperationTimeout, - client_exceptions.OperationTimeout, - id="OperationTimeout -> OperationTimeout", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - def test_raises_error_on_pipeline_op_error( - self, - mocker, - client_manual_cb, - 
mqtt_pipeline_manual_cb, - twin_patch_reported, - pipeline_error, - client_error, - ): - my_pipeline_error = pipeline_error() - self.add_event_completion_checks( - mocker=mocker, - pipeline_function=mqtt_pipeline_manual_cb.patch_twin_reported_properties, - kwargs={"error": my_pipeline_error}, - ) - with pytest.raises(client_error) as e_info: - client_manual_cb.patch_twin_reported_properties(twin_patch_reported) - assert e_info.value.__cause__ is my_pipeline_error - - -class SharedClientReceiveTwinDesiredPropertiesPatchTests(object): - @pytest.mark.it( - "Implicitly enables Twin desired properties patch feature if not already enabled" - ) - def test_enables_twin_patches_only_if_not_already_enabled(self, mocker, client, mqtt_pipeline): - mocker.patch.object( - SyncClientInbox, "get" - ) # patch this so receive_twin_desired_properties_patch won't block - - # Verify twin patches enabled if not enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = ( - False # twin patches will appear disabled - ) - client.receive_twin_desired_properties_patch() - assert mqtt_pipeline.enable_feature.call_count == 1 - assert mqtt_pipeline.enable_feature.call_args[0][0] == pipeline_constant.TWIN_PATCHES - - mqtt_pipeline.enable_feature.reset_mock() - - # Verify twin patches not enabled if already enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = True # C2D will appear enabled - client.receive_twin_desired_properties_patch() - assert mqtt_pipeline.enable_feature.call_count == 0 - - @pytest.mark.it("Returns a patch from the twin patch inbox, if available") - def test_returns_message_from_twin_patch_inbox(self, mocker, client, twin_patch_desired): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - inbox_mock.get.return_value = twin_patch_desired - manager_get_inbox_mock = mocker.patch.object( - client._inbox_manager, "get_twin_patch_inbox", return_value=inbox_mock - ) - - received_patch = client.receive_twin_desired_properties_patch() - assert 
manager_get_inbox_mock.call_count == 1 - assert inbox_mock.get.call_count == 1 - assert received_patch is twin_patch_desired - - @pytest.mark.it("Can be called in various modes") - @pytest.mark.parametrize( - "block,timeout", - [ - pytest.param(True, None, id="Blocking, no timeout"), - pytest.param(True, 10, id="Blocking with timeout"), - pytest.param(False, None, id="Nonblocking"), - ], - ) - def test_can_be_called_in_mode(self, mocker, client, block, timeout): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object(client._inbox_manager, "get_twin_patch_inbox", return_value=inbox_mock) - - client.receive_twin_desired_properties_patch(block=block, timeout=timeout) - assert inbox_mock.get.call_count == 1 - assert inbox_mock.get.call_args == mocker.call(block=block, timeout=timeout) - - @pytest.mark.it("Defaults to blocking mode with no timeout") - def test_default_mode(self, mocker, client): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object(client._inbox_manager, "get_twin_patch_inbox", return_value=inbox_mock) - - client.receive_twin_desired_properties_patch() - assert inbox_mock.get.call_count == 1 - assert inbox_mock.get.call_args == mocker.call(block=True, timeout=None) - - @pytest.mark.it("Blocks until a patch is available, in blocking mode") - def test_no_message_in_inbox_blocking_mode(self, client, twin_patch_desired): - - twin_patch_inbox = client._inbox_manager.get_twin_patch_inbox() - assert twin_patch_inbox.empty() - - def insert_item_after_delay(): - time.sleep(0.01) - twin_patch_inbox.put(twin_patch_desired) - - insertion_thread = threading.Thread(target=insert_item_after_delay) - insertion_thread.start() - - received_patch = client.receive_twin_desired_properties_patch(block=True) - assert received_patch is twin_patch_desired - # This proves that the blocking happens because 'received_patch' can't be - # 'twin_patch_desired' until after a 10 millisecond delay on the insert. 
But because the - # 'received_patch' IS 'twin_patch_desired', it means that client.receive_twin_desired_properties_patch - # did not return until after the delay. - - @pytest.mark.it( - "Returns None after a timeout while blocking, in blocking mode with a specified timeout" - ) - def test_times_out_waiting_for_message_blocking_mode(self, client): - result = client.receive_twin_desired_properties_patch(block=True, timeout=0.01) - assert result is None - - @pytest.mark.it("Returns None immediately if there are no patches, in nonblocking mode") - def test_no_message_in_inbox_nonblocking_mode(self, client): - result = client.receive_twin_desired_properties_patch(block=False) - assert result is None - - @pytest.mark.it("Locks the client to API Receive Mode if the receive mode has not yet been set") - @pytest.mark.parametrize( - "block,timeout", - [ - pytest.param(True, None, id="Blocking, no timeout"), - pytest.param(True, 10, id="Blocking with timeout"), - pytest.param(False, None, id="Nonblocking"), - ], - ) - def test_receive_mode_not_set(self, mocker, client, block, timeout): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object(client._inbox_manager, "get_twin_patch_inbox", return_value=inbox_mock) - - assert client._receive_type is RECEIVE_TYPE_NONE_SET - client.receive_twin_desired_properties_patch(block=block, timeout=timeout) - assert client._receive_type is RECEIVE_TYPE_API - - @pytest.mark.it( - "Does not modify the client receive mode if it has already been set to API Receive Mode" - ) - @pytest.mark.parametrize( - "block,timeout", - [ - pytest.param(True, None, id="Blocking, no timeout"), - pytest.param(True, 10, id="Blocking with timeout"), - pytest.param(False, None, id="Nonblocking"), - ], - ) - def test_receive_mode_set_api(self, mocker, client, block, timeout): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object(client._inbox_manager, "get_twin_patch_inbox", return_value=inbox_mock) - - 
client._receive_type = RECEIVE_TYPE_API - client.receive_twin_desired_properties_patch(block=block, timeout=timeout) - assert client._receive_type is RECEIVE_TYPE_API - - @pytest.mark.it( - "Raises a ClientError and does nothing else if the client receive mode has been set to Handler Receive Mode" - ) - @pytest.mark.parametrize( - "block,timeout", - [ - pytest.param(True, None, id="Blocking, no timeout"), - pytest.param(True, 10, id="Blocking with timeout"), - pytest.param(False, None, id="Nonblocking"), - ], - ) - def test_receive_mode_set_handler(self, mocker, client, mqtt_pipeline, block, timeout): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object(client._inbox_manager, "get_twin_patch_inbox", return_value=inbox_mock) - # patch this so we can make sure feature enabled isn't modified - mqtt_pipeline.feature_enabled.__getitem__.return_value = False - - client._receive_type = RECEIVE_TYPE_HANDLER - # Error was raised - with pytest.raises(client_exceptions.ClientError): - client.receive_twin_desired_properties_patch(block=block, timeout=timeout) - # Feature was not enabled - assert mqtt_pipeline.enable_feature.call_count == 0 - # Inbox get was not called - assert inbox_mock.get.call_count == 0 - - -################ -# DEVICE TESTS # -################ -class IoTHubDeviceClientTestsConfig(object): - @pytest.fixture - def client_class(self): - return IoTHubDeviceClient - - @pytest.fixture - def client(self, mqtt_pipeline, http_pipeline): - """This client automatically resolves callbacks sent to the pipeline. - It should be used for the majority of tests. - """ - return IoTHubDeviceClient(mqtt_pipeline, http_pipeline) - - @pytest.fixture - def client_manual_cb(self, mqtt_pipeline_manual_cb, http_pipeline_manual_cb): - """This client requires manual triggering of the callbacks sent to the pipeline. - It should only be used for tests where manual control fo a callback is required. 
- """ - return IoTHubDeviceClient(mqtt_pipeline_manual_cb, http_pipeline_manual_cb) - - @pytest.fixture - def connection_string(self, device_connection_string): - """This fixture is parametrized to prove all valid device connection strings. - See client_fixtures.py - """ - return device_connection_string - - @pytest.fixture - def sas_token_string(self, device_sas_token_string): - return device_sas_token_string - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - Instantiation") -class TestIoTHubDeviceClientInstantiation( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientInstantiationTests -): - @pytest.mark.it("Sets on_c2d_message_received handler in the MQTTPipeline") - def test_sets_on_c2d_message_received_handler_in_pipeline( - self, client_class, mqtt_pipeline, http_pipeline - ): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline.on_c2d_message_received is not None - assert ( - client._mqtt_pipeline.on_c2d_message_received == client._inbox_manager.route_c2d_message - ) - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .create_from_connection_string()") -class TestIoTHubDeviceClientCreateFromConnectionString( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientCreateFromConnectionStringTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .create_from_sastoken()") -class TestIoTHubDeviceClientCreateFromSastoken( - IoTHubDeviceClientTestsConfig, SharedIoTHubDeviceClientCreateFromSastokenTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .create_from_symmetric_key()") -class TestIoTHubDeviceClientCreateFromSymmetricKey( - IoTHubDeviceClientTestsConfig, SharedIoTHubDeviceClientCreateFromSymmetricKeyTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .create_from_x509_certificate()") -class TestIoTHubDeviceClientCreateFromX509Certificate( - IoTHubDeviceClientTestsConfig, SharedIoTHubDeviceClientCreateFromX509CertificateTests 
-): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .shutdown()") -class TestIoTHubDeviceClientShutdown(IoTHubDeviceClientTestsConfig, SharedClientShutdownTests): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .update_sastoken()") -class TestIoTHubDeviceClientUpdateSasToken( - IoTHubDeviceClientTestsConfig, SharedClientUpdateSasTokenTests -): - @pytest.fixture - def sas_config(self, sas_token_string): - """PipelineConfig set up as if using user-provided, non-renewable SAS auth""" - sastoken = st.NonRenewableSasToken(sas_token_string) - token_uri_pieces = sastoken.resource_uri.split("/") - hostname = token_uri_pieces[0] - device_id = token_uri_pieces[2] - sas_config = IoTHubPipelineConfig(hostname=hostname, device_id=device_id, sastoken=sastoken) - return sas_config - - @pytest.fixture - def sas_client(self, mqtt_pipeline, http_pipeline, sas_config): - """Client configured as if using user-provided, non-renewable SAS auth""" - mqtt_pipeline.pipeline_configuration = sas_config - http_pipeline.pipeline_configuration = sas_config - return IoTHubDeviceClient(mqtt_pipeline, http_pipeline) - - @pytest.fixture - def uri(self, hostname, device_id): - return "{hostname}/devices/{device_id}".format(hostname=hostname, device_id=device_id) - - @pytest.fixture(params=["Nonmatching Device ID", "Nonmatching Hostname"]) - def nonmatching_uri(self, request, device_id, hostname): - # NOTE: It would be preferable to have this as a parametrization on a test rather than a - # fixture, however, we need to use the device_id and hostname fixtures in order to ensure - # tests don't break when other fixtures change, and you can't include fixtures in a - # parametrization, so this also has to be a fixture - uri_format = "{hostname}/devices/{device_id}" - if request.param == "Nonmatching Device ID": - return uri_format.format(hostname=hostname, device_id="nonmatching_device") - else: - return uri_format.format(hostname="nonmatching_hostname", 
device_id=device_id) - - @pytest.fixture( - params=["Too short", "Too long", "Incorrectly formatted device notation", "Module URI"] - ) - def invalid_uri(self, request, device_id, hostname): - # NOTE: As in the nonmatching_uri fixture above, this is a workaround for parametrization - # that allows the usage of other fixtures in the parametrized value. Weird pattern, but - # necessary to ensure stability of the tests over time. - if request.param == "Too short": - # Doesn't have device ID - return hostname + "/devices" - elif request.param == "Too long": - # Extraneous value at the end - return "{}/devices/{}/somethingElse".format(hostname, device_id) - elif request.param == "Incorrectly formatted device notation": - # Doesn't have '/devices/' - return "{}/not-devices/{}".format(hostname, device_id) - else: - # Valid... for a Module... but this is a Device - return "{}/devices/{}/modules/my_module".format(hostname, device_id) - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .connect()") -class TestIoTHubDeviceClientConnect(IoTHubDeviceClientTestsConfig, SharedClientConnectTests): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .disconnect()") -class TestIoTHubDeviceClientDisconnect(IoTHubDeviceClientTestsConfig, SharedClientDisconnectTests): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .send_message()") -class TestIoTHubDeviceClientSendD2CMessage( - IoTHubDeviceClientTestsConfig, SharedClientSendD2CMessageTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .receive_message()") -class TestIoTHubDeviceClientReceiveC2DMessage( - IoTHubDeviceClientTestsConfig, WaitsForEventCompletion -): - @pytest.mark.it("Implicitly enables C2D messaging feature if not already enabled") - def test_enables_c2d_messaging_only_if_not_already_enabled(self, mocker, client, mqtt_pipeline): - mocker.patch.object(SyncClientInbox, "get") # patch this so receive_message won't block - - # Verify C2D Messaging 
enabled if not enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = False # C2D will appear disabled - client.receive_message() - assert mqtt_pipeline.enable_feature.call_count == 1 - assert mqtt_pipeline.enable_feature.call_args[0][0] == pipeline_constant.C2D_MSG - - mqtt_pipeline.enable_feature.reset_mock() - - # Verify C2D Messaging not enabled if already enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = True # C2D will appear enabled - client.receive_message() - assert mqtt_pipeline.enable_feature.call_count == 0 - - @pytest.mark.it("Returns a message from the C2D inbox, if available") - def test_returns_message_from_c2d_inbox(self, mocker, client, message): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - inbox_mock.get.return_value = message - manager_get_inbox_mock = mocker.patch.object( - client._inbox_manager, "get_c2d_message_inbox", return_value=inbox_mock - ) - - received_message = client.receive_message() - assert manager_get_inbox_mock.call_count == 1 - assert inbox_mock.get.call_count == 1 - assert received_message is message - - @pytest.mark.it("Can be called in various modes") - @pytest.mark.parametrize( - "block,timeout", - [ - pytest.param(True, None, id="Blocking, no timeout"), - pytest.param(True, 10, id="Blocking with timeout"), - pytest.param(False, None, id="Nonblocking"), - ], - ) - def test_can_be_called_in_mode(self, mocker, client, block, timeout): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object(client._inbox_manager, "get_c2d_message_inbox", return_value=inbox_mock) - - client.receive_message(block=block, timeout=timeout) - assert inbox_mock.get.call_count == 1 - assert inbox_mock.get.call_args == mocker.call(block=block, timeout=timeout) - - @pytest.mark.it("Defaults to blocking mode with no timeout") - def test_default_mode(self, mocker, client): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object(client._inbox_manager, 
"get_c2d_message_inbox", return_value=inbox_mock) - - client.receive_message() - assert inbox_mock.get.call_count == 1 - assert inbox_mock.get.call_args == mocker.call(block=True, timeout=None) - - @pytest.mark.it("Blocks until a message is available, in blocking mode") - def test_no_message_in_inbox_blocking_mode(self, client, message): - c2d_inbox = client._inbox_manager.get_c2d_message_inbox() - assert c2d_inbox.empty() - - def insert_item_after_delay(): - time.sleep(0.01) - c2d_inbox.put(message) - - insertion_thread = threading.Thread(target=insert_item_after_delay) - insertion_thread.start() - - received_message = client.receive_message(block=True) - assert received_message is message - # This proves that the blocking happens because 'received_message' can't be - # 'message' until after a 10 millisecond delay on the insert. But because the - # 'received_message' IS 'message', it means that client.receive_message - # did not return until after the delay. - - @pytest.mark.it( - "Returns None after a timeout while blocking, in blocking mode with a specified timeout" - ) - def test_times_out_waiting_for_message_blocking_mode(self, client): - result = client.receive_message(block=True, timeout=0.01) - assert result is None - - @pytest.mark.it("Returns None immediately if there are no messages, in nonblocking mode") - def test_no_message_in_inbox_nonblocking_mode(self, client): - result = client.receive_message(block=False) - assert result is None - - @pytest.mark.it("Locks the client to API Receive Mode if the receive mode has not yet been set") - @pytest.mark.parametrize( - "block,timeout", - [ - pytest.param(True, None, id="Blocking, no timeout"), - pytest.param(True, 10, id="Blocking with timeout"), - pytest.param(False, None, id="Nonblocking"), - ], - ) - def test_receive_mode_not_set(self, mocker, client, block, timeout): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object(client._inbox_manager, "get_c2d_message_inbox", 
return_value=inbox_mock) - - assert client._receive_type is RECEIVE_TYPE_NONE_SET - client.receive_message(block=block, timeout=timeout) - assert client._receive_type is RECEIVE_TYPE_API - - @pytest.mark.it( - "Does not modify the client receive mode if it has already been set to API Receive Mode" - ) - @pytest.mark.parametrize( - "block,timeout", - [ - pytest.param(True, None, id="Blocking, no timeout"), - pytest.param(True, 10, id="Blocking with timeout"), - pytest.param(False, None, id="Nonblocking"), - ], - ) - def test_receive_mode_set_api(self, mocker, client, block, timeout): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object(client._inbox_manager, "get_c2d_message_inbox", return_value=inbox_mock) - - client._receive_type = RECEIVE_TYPE_API - client.receive_message(block=block, timeout=timeout) - assert client._receive_type is RECEIVE_TYPE_API - - @pytest.mark.it( - "Raises a ClientError and does nothing else if the client receive mode has been set to Handler Receive Mode" - ) - @pytest.mark.parametrize( - "block,timeout", - [ - pytest.param(True, None, id="Blocking, no timeout"), - pytest.param(True, 10, id="Blocking with timeout"), - pytest.param(False, None, id="Nonblocking"), - ], - ) - def test_receive_mode_set_handler(self, mocker, client, mqtt_pipeline, block, timeout): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object(client._inbox_manager, "get_c2d_message_inbox", return_value=inbox_mock) - # patch this so we can make sure feature enabled isn't modified - mqtt_pipeline.feature_enabled.__getitem__.return_value = False - - client._receive_type = RECEIVE_TYPE_HANDLER - # Error was raised - with pytest.raises(client_exceptions.ClientError): - client.receive_message(block=block, timeout=timeout) - # Feature was not enabled - assert mqtt_pipeline.enable_feature.call_count == 0 - # Inbox get was not called - assert inbox_mock.get.call_count == 0 - - -@pytest.mark.describe("IoTHubDeviceClient 
(Synchronous) - .receive_method_request()") -class TestIoTHubDeviceClientReceiveMethodRequest( - IoTHubDeviceClientTestsConfig, SharedClientReceiveMethodRequestTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .send_method_response()") -class TestIoTHubDeviceClientSendMethodResponse( - IoTHubDeviceClientTestsConfig, SharedClientSendMethodResponseTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .get_twin()") -class TestIoTHubDeviceClientGetTwin(IoTHubDeviceClientTestsConfig, SharedClientGetTwinTests): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .patch_twin_reported_properties()") -class TestIoTHubDeviceClientPatchTwinReportedProperties( - IoTHubDeviceClientTestsConfig, SharedClientPatchTwinReportedPropertiesTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .receive_twin_desired_properties_patch()") -class TestIoTHubDeviceClientReceiveTwinDesiredPropertiesPatch( - IoTHubDeviceClientTestsConfig, SharedClientReceiveTwinDesiredPropertiesPatchTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .get_storage_info_for_blob()") -class TestIoTHubDeviceClientGetStorageInfo(WaitsForEventCompletion, IoTHubDeviceClientTestsConfig): - @pytest.mark.it("Begins a 'get_storage_info_for_blob' HTTPPipeline operation") - def test_calls_pipeline_get_storage_info_for_blob(self, mocker, client, http_pipeline): - fake_blob_name = "__fake_blob_name__" - client.get_storage_info_for_blob(fake_blob_name) - assert http_pipeline.get_storage_info_for_blob.call_count == 1 - assert http_pipeline.get_storage_info_for_blob.call_args == mocker.call( - fake_blob_name, callback=mocker.ANY - ) - - @pytest.mark.it( - "Waits for the completion of the 'get_storage_info_for_blob' pipeline operation before returning" - ) - def test_waits_for_pipeline_op_completion( - self, mocker, client_manual_cb, http_pipeline_manual_cb - ): - fake_blob_name = "__fake_blob_name__" - 
- self.add_event_completion_checks( - mocker=mocker, - pipeline_function=http_pipeline_manual_cb.get_storage_info_for_blob, - kwargs={"storage_info": "__fake_storage_info__"}, - ) - - client_manual_cb.get_storage_info_for_blob(fake_blob_name) - - @pytest.mark.it( - "Raises a client error if the `get_storage_info_for_blob` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - def test_raises_error_on_pipeline_op_error( - self, mocker, client_manual_cb, http_pipeline_manual_cb, pipeline_error, client_error - ): - fake_blob_name = "__fake_blob_name__" - my_pipeline_error = pipeline_error() - self.add_event_completion_checks( - mocker=mocker, - pipeline_function=http_pipeline_manual_cb.get_storage_info_for_blob, - kwargs={"error": my_pipeline_error}, - ) - with pytest.raises(client_error) as e_info: - client_manual_cb.get_storage_info_for_blob(fake_blob_name) - assert e_info.value.__cause__ is my_pipeline_error - - @pytest.mark.it("Returns a storage_info object upon successful completion") - def test_returns_storage_info(self, mocker, client, http_pipeline): - fake_blob_name = "__fake_blob_name__" - fake_storage_info = "__fake_storage_info__" - received_storage_info = client.get_storage_info_for_blob(fake_blob_name) - assert http_pipeline.get_storage_info_for_blob.call_count == 1 - assert http_pipeline.get_storage_info_for_blob.call_args == mocker.call( - fake_blob_name, callback=mocker.ANY - ) - - assert ( - received_storage_info is fake_storage_info - ) # Note: the return value this is checking for is defined in 
client_fixtures.py - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .notify_blob_upload_status()") -class TestIoTHubDeviceClientNotifyBlobUploadStatus( - WaitsForEventCompletion, IoTHubDeviceClientTestsConfig -): - @pytest.mark.it("Begins a 'notify_blob_upload_status' HTTPPipeline operation") - def test_calls_pipeline_notify_blob_upload_status(self, client, http_pipeline): - correlation_id = "__fake_correlation_id__" - is_success = "__fake_is_success__" - status_code = "__fake_status_code__" - status_description = "__fake_status_description__" - client.notify_blob_upload_status( - correlation_id, is_success, status_code, status_description - ) - kwargs = http_pipeline.notify_blob_upload_status.call_args[1] - assert http_pipeline.notify_blob_upload_status.call_count == 1 - assert kwargs["correlation_id"] is correlation_id - assert kwargs["is_success"] is is_success - assert kwargs["status_code"] is status_code - assert kwargs["status_description"] is status_description - - @pytest.mark.it( - "Waits for the completion of the 'notify_blob_upload_status' pipeline operation before returning" - ) - def test_waits_for_pipeline_op_completion( - self, mocker, client_manual_cb, http_pipeline_manual_cb - ): - correlation_id = "__fake_correlation_id__" - is_success = "__fake_is_success__" - status_code = "__fake_status_code__" - status_description = "__fake_status_description__" - self.add_event_completion_checks( - mocker=mocker, pipeline_function=http_pipeline_manual_cb.notify_blob_upload_status - ) - - client_manual_cb.notify_blob_upload_status( - correlation_id, is_success, status_code, status_description - ) - - @pytest.mark.it( - "Raises a client error if the `notify_blob_upload_status` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( 
- pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - def test_raises_error_on_pipeline_op_error( - self, mocker, client_manual_cb, http_pipeline_manual_cb, pipeline_error, client_error - ): - correlation_id = "__fake_correlation_id__" - is_success = "__fake_is_success__" - status_code = "__fake_status_code__" - status_description = "__fake_status_description__" - my_pipeline_error = pipeline_error() - self.add_event_completion_checks( - mocker=mocker, - pipeline_function=http_pipeline_manual_cb.notify_blob_upload_status, - kwargs={"error": my_pipeline_error}, - ) - with pytest.raises(client_error) as e_info: - client_manual_cb.notify_blob_upload_status( - correlation_id, is_success, status_code, status_description - ) - assert e_info.value.__cause__ is my_pipeline_error - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - PROPERTY .on_message_received") -class TestIoTHubDeviceClientPROPERTYOnMessageReceivedHandler( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientPROPERTYReceiverHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_message_received" - - @pytest.fixture - def feature_name(self): - return pipeline_constant.C2D_MSG - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - PROPERTY .on_method_request_received") -class TestIoTHubDeviceClientPROPERTYOnMethodRequestReceivedHandler( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientPROPERTYReceiverHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_method_request_received" - - @pytest.fixture - def feature_name(self): - return pipeline_constant.METHODS - - -@pytest.mark.describe( - "IoTHubDeviceClient (Synchronous) - PROPERTY .on_twin_desired_properties_patch_received" -) -class TestIoTHubDeviceClientPROPERTYOnTwinDesiredPropertiesPatchReceivedHandler( - 
IoTHubDeviceClientTestsConfig, SharedIoTHubClientPROPERTYReceiverHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_twin_desired_properties_patch_received" - - @pytest.fixture - def feature_name(self): - return pipeline_constant.TWIN_PATCHES - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - PROPERTY .on_connection_state_change") -class TestIoTHubDeviceClientPROPERTYOnConnectionStateChangeHandler( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientPROPERTYHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_connection_state_change" - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - PROPERTY .on_new_sastoken_required") -class TestIoTHubDeviceClientPROPERTYOnNewSastokenRequiredHandler( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientPROPERTYHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_new_sastoken_required" - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - PROPERTY .on_background_exception") -class TestIoTHubDeviceClientPROPERTYOnBackgroundExceptionHandler( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientPROPERTYHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_background_exception" - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - PROPERTY .connected") -class TestIoTHubDeviceClientPROPERTYConnected( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientPROPERTYConnectedTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - OCCURRENCE: Connect") -class TestIoTHubDeviceClientOCCURRENCEConnect( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientOCCURRENCEConnectTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - OCCURRENCE: Disconnect") -class TestIoTHubDeviceClientOCCURRENCEDisconnect( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientOCCURRENCEDisconnectTests -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - OCCURRENCE: New Sastoken Required") 
-class TestIoTHubDeviceClientOCCURRENCENewSastokenRequired( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientOCCURRENCENewSastokenRequired -): - pass - - -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - OCCURRENCE: Background Exception") -class TestIoTHubDeviceClientOCCURRENCEBackgroundException( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientOCCURRENCEBackgroundException -): - pass - - -################ -# MODULE TESTS # -################ -class IoTHubModuleClientTestsConfig(object): - @pytest.fixture - def client_class(self): - return IoTHubModuleClient - - @pytest.fixture - def client(self, mqtt_pipeline, http_pipeline): - """This client automatically resolves callbacks sent to the pipeline. - It should be used for the majority of tests. - """ - return IoTHubModuleClient(mqtt_pipeline, http_pipeline) - - @pytest.fixture - def client_manual_cb(self, mqtt_pipeline_manual_cb, http_pipeline_manual_cb): - """This client requires manual triggering of the callbacks sent to the pipeline. - It should only be used for tests where manual control fo a callback is required. - """ - return IoTHubModuleClient(mqtt_pipeline_manual_cb, http_pipeline_manual_cb) - - @pytest.fixture - def connection_string(self, module_connection_string): - """This fixture is parametrized to prove all valid device connection strings. 
- See client_fixtures.py - """ - return module_connection_string - - @pytest.fixture - def sas_token_string(self, module_sas_token_string): - return module_sas_token_string - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - Instantiation") -class TestIoTHubModuleClientInstantiation( - IoTHubModuleClientTestsConfig, SharedIoTHubClientInstantiationTests -): - @pytest.mark.it("Sets on_input_message_received handler in the MQTTPipeline") - def test_sets_on_input_message_received_handler_in_pipeline( - self, client_class, mqtt_pipeline, http_pipeline - ): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline.on_input_message_received is not None - assert ( - client._mqtt_pipeline.on_input_message_received - == client._inbox_manager.route_input_message - ) - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - .create_from_connection_string()") -class TestIoTHubModuleClientCreateFromConnectionString( - IoTHubModuleClientTestsConfig, SharedIoTHubClientCreateFromConnectionStringTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - .create_from_sastoken()") -class TestIoTHubModuleClientCreateFromSastoken( - IoTHubModuleClientTestsConfig, SharedIoTHubModuleClientCreateFromSastokenTests -): - pass - - -@pytest.mark.describe( - "IoTHubModuleClient (Synchronous) - .create_from_edge_environment() -- Edge Container Environment" -) -class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnv( - IoTHubModuleClientTestsConfig, - SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnvTests, -): - pass - - -@pytest.mark.describe( - "IoTHubModuleClient (Synchronous) - .create_from_edge_environment() -- Edge Local Debug Environment" -) -class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv( - IoTHubModuleClientTestsConfig, - SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnvTests, -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - 
.create_from_x509_certificate()") -class TestIoTHubModuleClientCreateFromX509Certificate( - IoTHubModuleClientTestsConfig, SharedIoTHubModuleClientCreateFromX509CertificateTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - .shutdown()") -class TestIoTHubModuleClientShutdown(IoTHubModuleClientTestsConfig, SharedClientShutdownTests): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - .update_sastoken()") -class TestIoTHubModuleClientUpdateSasToken( - IoTHubModuleClientTestsConfig, SharedClientUpdateSasTokenTests -): - @pytest.fixture - def module_id(self, sas_token_string): - # NOTE: This is kind of unconventional, but this is the easiest way to extract the - # module id from a sastoken string - sastoken = st.NonRenewableSasToken(sas_token_string) - token_uri_pieces = sastoken.resource_uri.split("/") - module_id = token_uri_pieces[4] - return module_id - - @pytest.fixture - def sas_config(self, sas_token_string): - """PipelineConfig set up as if using user-provided, non-renewable SAS auth""" - sastoken = st.NonRenewableSasToken(sas_token_string) - token_uri_pieces = sastoken.resource_uri.split("/") - hostname = token_uri_pieces[0] - device_id = token_uri_pieces[2] - module_id = token_uri_pieces[4] - sas_config = IoTHubPipelineConfig( - hostname=hostname, device_id=device_id, module_id=module_id, sastoken=sastoken - ) - return sas_config - - @pytest.fixture - def uri(self, hostname, device_id, module_id): - return "{hostname}/devices/{device_id}/modules/{module_id}".format( - hostname=hostname, device_id=device_id, module_id=module_id - ) - - @pytest.fixture( - params=["Nonmatching Device ID", "Nonmatching Module ID", "Nonmatching Hostname"] - ) - def nonmatching_uri(self, request, device_id, module_id, hostname): - # NOTE: It would be preferable to have this as a parametrization on a test rather than a - # fixture, however, we need to use the device_id and hostname fixtures in order to ensure - # tests don't break when 
other fixtures change, and you can't include fixtures in a - # parametrization, so this also has to be a fixture - uri_format = "{hostname}/devices/{device_id}/modules/{module_id}" - if request.param == "Nonmatching Device ID": - return uri_format.format( - hostname=hostname, device_id="nonmatching_device", module_id=module_id - ) - elif request.param == "Nonmatching Module ID": - return uri_format.format( - hostname=hostname, device_id=device_id, module_id="nonmatching_module" - ) - else: - return uri_format.format( - hostname="nonmatching_hostname", device_id=device_id, module_id=module_id - ) - - @pytest.fixture( - params=[ - "Too short", - "Too long", - "Incorrectly formatted device notation", - "Incorrectly formatted module notation", - "Device URI", - ] - ) - def invalid_uri(self, request, device_id, module_id, hostname): - # NOTE: As in the nonmatching_uri fixture above, this is a workaround for parametrization - # that allows the usage of other fixtures in the parametrized value. Weird pattern, but - # necessary to ensure stability of the tests over time. - if request.param == "Too short": - # Doesn't have module ID - return "{}/devices/{}/modules".format(hostname, device_id) - elif request.param == "Too long": - # Extraneous value at the end - return "{}/devices/{}/modules/{}/somethingElse".format(hostname, device_id, module_id) - elif request.param == "Incorrectly formatted device notation": - # Doesn't have '/devices/' - return "{}/not-devices/{}/modules/{}".format(hostname, device_id, module_id) - elif request.param == "Incorrectly formatted module notation": - # Doesn't have '/modules/' - return "{}/devices/{}/not-modules/{}".format(hostname, device_id, module_id) - else: - # Valid... for a Device... 
but this is a Module - return "{}/devices/{}/".format(hostname, device_id) - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - .connect()") -class TestIoTHubModuleClientConnect(IoTHubModuleClientTestsConfig, SharedClientConnectTests): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - .disconnect()") -class TestIoTHubModuleClientDisconnect(IoTHubModuleClientTestsConfig, SharedClientDisconnectTests): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - .send_message()") -class TestIoTHubNModuleClientSendD2CMessage( - IoTHubModuleClientTestsConfig, SharedClientSendD2CMessageTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - .send_message_to_output()") -class TestIoTHubModuleClientSendToOutput(IoTHubModuleClientTestsConfig, WaitsForEventCompletion): - @pytest.mark.it("Begins a 'send_output_message' pipeline operation") - def test_calls_pipeline_send_message_to_output(self, client, mqtt_pipeline, message): - output_name = "some_output" - client.send_message_to_output(message, output_name) - assert mqtt_pipeline.send_output_message.call_count == 1 - assert mqtt_pipeline.send_output_message.call_args[0][0] is message - assert message.output_name == output_name - - @pytest.mark.it( - "Waits for the completion of the 'send_output_message' pipeline operation before returning" - ) - def test_waits_for_pipeline_op_completion( - self, mocker, client_manual_cb, mqtt_pipeline_manual_cb, message - ): - self.add_event_completion_checks( - mocker=mocker, pipeline_function=mqtt_pipeline_manual_cb.send_output_message - ) - output_name = "some_output" - client_manual_cb.send_message_to_output(message, output_name) - - @pytest.mark.it( - "Raises a client error if the `send_out_event` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ConnectionDroppedError, - client_exceptions.ConnectionDroppedError, - 
id="ConnectionDroppedError->ConnectionDroppedError", - ), - pytest.param( - pipeline_exceptions.ConnectionFailedError, - client_exceptions.ConnectionFailedError, - id="ConnectionFailedError->ConnectionFailedError", - ), - pytest.param( - pipeline_exceptions.NoConnectionError, - client_exceptions.NoConnectionError, - id="NoConnectionError->NoConnectionError", - ), - pytest.param( - pipeline_exceptions.UnauthorizedError, - client_exceptions.CredentialError, - id="UnauthorizedError->CredentialError", - ), - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param( - pipeline_exceptions.OperationTimeout, - client_exceptions.OperationTimeout, - id="OperationTimeout -> OperationTimeout", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - def test_raises_error_on_pipeline_op_error( - self, - mocker, - client_manual_cb, - mqtt_pipeline_manual_cb, - message, - pipeline_error, - client_error, - ): - my_pipeline_error = pipeline_error() - self.add_event_completion_checks( - mocker=mocker, - pipeline_function=mqtt_pipeline_manual_cb.send_output_message, - kwargs={"error": my_pipeline_error}, - ) - output_name = "some_output" - with pytest.raises(client_error) as e_info: - client_manual_cb.send_message_to_output(message, output_name) - assert e_info.value.__cause__ is my_pipeline_error - - @pytest.mark.it( - "Wraps 'message' input parameter in Message object if it is not a Message object" - ) - @pytest.mark.parametrize( - "message_input", - [ - pytest.param("message", id="String input"), - pytest.param(222, id="Integer input"), - pytest.param(object(), id="Object input"), - pytest.param(None, id="None input"), - pytest.param([1, "str"], id="List input"), - pytest.param({"a": 2}, 
id="Dictionary input"), - ], - ) - def test_send_message_to_output_calls_pipeline_wraps_data_in_message( - self, client, mqtt_pipeline, message_input - ): - output_name = "some_output" - client.send_message_to_output(message_input, output_name) - assert mqtt_pipeline.send_output_message.call_count == 1 - sent_message = mqtt_pipeline.send_output_message.call_args[0][0] - assert isinstance(sent_message, Message) - assert sent_message.data == message_input - - @pytest.mark.it("Raises error when message data size is greater than 256 KB") - def test_raises_error_when_message_to_output_data_greater_than_256(self, client, mqtt_pipeline): - output_name = "some_output" - data_input = "serpensortia" * 256000 - message = Message(data_input) - with pytest.raises(ValueError) as e_info: - client.send_message_to_output(message, output_name) - assert "256 KB" in e_info.value.args[0] - assert mqtt_pipeline.send_output_message.call_count == 0 - - @pytest.mark.it("Raises error when message size is greater than 256 KB") - def test_raises_error_when_message_to_output_size_greater_than_256(self, client, mqtt_pipeline): - output_name = "some_output" - data_input = "serpensortia" - message = Message(data_input) - message.custom_properties["spell"] = data_input * 256000 - with pytest.raises(ValueError) as e_info: - client.send_message_to_output(message, output_name) - assert "256 KB" in e_info.value.args[0] - assert mqtt_pipeline.send_output_message.call_count == 0 - - @pytest.mark.it("Does not raises error when message data size is equal to 256 KB") - def test_raises_error_when_message_to_output_data_equal_to_256(self, client, mqtt_pipeline): - output_name = "some_output" - data_input = "a" * 262095 - message = Message(data_input) - # This check was put as message class may undergo the default content type encoding change - # and the above calculation will change. 
- if message.get_size() != device_constant.TELEMETRY_MESSAGE_SIZE_LIMIT: - assert False - - client.send_message_to_output(message, output_name) - - assert mqtt_pipeline.send_output_message.call_count == 1 - sent_message = mqtt_pipeline.send_output_message.call_args[0][0] - assert isinstance(sent_message, Message) - assert sent_message.data == data_input - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - .receive_message_on_input()") -class TestIoTHubModuleClientReceiveInputMessage(IoTHubModuleClientTestsConfig): - @pytest.mark.it("Implicitly enables input messaging feature if not already enabled") - def test_enables_input_messaging_only_if_not_already_enabled( - self, mocker, client, mqtt_pipeline - ): - mocker.patch.object( - SyncClientInbox, "get" - ) # patch this receive_message_on_input won't block - input_name = "some_input" - - # Verify Input Messaging enabled if not enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = ( - False # Input Messages will appear disabled - ) - client.receive_message_on_input(input_name) - assert mqtt_pipeline.enable_feature.call_count == 1 - assert mqtt_pipeline.enable_feature.call_args[0][0] == pipeline_constant.INPUT_MSG - - mqtt_pipeline.enable_feature.reset_mock() - - # Verify Input Messaging not enabled if already enabled - mqtt_pipeline.feature_enabled.__getitem__.return_value = ( - True # Input Messages will appear enabled - ) - client.receive_message_on_input(input_name) - assert mqtt_pipeline.enable_feature.call_count == 0 - - @pytest.mark.it("Returns a message from the input inbox, if available") - def test_returns_message_from_input_inbox(self, mocker, client, message): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - inbox_mock.get.return_value = message - manager_get_inbox_mock = mocker.patch.object( - client._inbox_manager, "get_input_message_inbox", return_value=inbox_mock - ) - - input_name = "some_input" - received_message = client.receive_message_on_input(input_name) - assert 
manager_get_inbox_mock.call_count == 1 - assert manager_get_inbox_mock.call_args == mocker.call(input_name) - assert inbox_mock.get.call_count == 1 - assert received_message is message - - @pytest.mark.it("Can be called in various modes") - @pytest.mark.parametrize( - "block,timeout", - [ - pytest.param(True, None, id="Blocking, no timeout"), - pytest.param(True, 10, id="Blocking with timeout"), - pytest.param(False, None, id="Nonblocking"), - ], - ) - def test_can_be_called_in_mode(self, mocker, client, block, timeout): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object( - client._inbox_manager, "get_input_message_inbox", return_value=inbox_mock - ) - - input_name = "some_input" - client.receive_message_on_input(input_name, block=block, timeout=timeout) - assert inbox_mock.get.call_count == 1 - assert inbox_mock.get.call_args == mocker.call(block=block, timeout=timeout) - - @pytest.mark.it("Defaults to blocking mode with no timeout") - def test_default_mode(self, mocker, client): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object( - client._inbox_manager, "get_input_message_inbox", return_value=inbox_mock - ) - - input_name = "some_input" - client.receive_message_on_input(input_name) - assert inbox_mock.get.call_count == 1 - assert inbox_mock.get.call_args == mocker.call(block=True, timeout=None) - - @pytest.mark.it("Blocks until a message is available, in blocking mode") - def test_no_message_in_inbox_blocking_mode(self, client, message): - input_name = "some_input" - - input_inbox = client._inbox_manager.get_input_message_inbox(input_name) - assert input_inbox.empty() - - def insert_item_after_delay(): - time.sleep(0.01) - input_inbox.put(message) - - insertion_thread = threading.Thread(target=insert_item_after_delay) - insertion_thread.start() - - received_message = client.receive_message_on_input(input_name, block=True) - assert received_message is message - # This proves that the blocking happens 
because 'received_message' can't be - # 'message' until after a 10 millisecond delay on the insert. But because the - # 'received_message' IS 'message', it means that client.receive_message_on_input - # did not return until after the delay. - - @pytest.mark.it( - "Returns None after a timeout while blocking, in blocking mode with a specified timeout" - ) - def test_times_out_waiting_for_message_blocking_mode(self, client): - input_name = "some_input" - result = client.receive_message_on_input(input_name, block=True, timeout=0.01) - assert result is None - - @pytest.mark.it("Returns None immediately if there are no messages, in nonblocking mode") - def test_no_message_in_inbox_nonblocking_mode(self, client): - input_name = "some_input" - result = client.receive_message_on_input(input_name, block=False) - assert result is None - - @pytest.mark.it("Locks the client to API Receive Mode if the receive mode has not yet been set") - @pytest.mark.parametrize( - "block,timeout", - [ - pytest.param(True, None, id="Blocking, no timeout"), - pytest.param(True, 10, id="Blocking with timeout"), - pytest.param(False, None, id="Nonblocking"), - ], - ) - def test_receive_mode_not_set(self, mocker, client, block, timeout): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object( - client._inbox_manager, "get_input_message_inbox", return_value=inbox_mock - ) - - assert client._receive_type is RECEIVE_TYPE_NONE_SET - client.receive_message_on_input(input_name="some_input", block=block, timeout=timeout) - assert client._receive_type is RECEIVE_TYPE_API - - @pytest.mark.it( - "Does not modify the client receive mode if it has already been set to API Receive Mode" - ) - @pytest.mark.parametrize( - "block,timeout", - [ - pytest.param(True, None, id="Blocking, no timeout"), - pytest.param(True, 10, id="Blocking with timeout"), - pytest.param(False, None, id="Nonblocking"), - ], - ) - def test_receive_mode_set_api(self, mocker, client, block, timeout): - inbox_mock = 
mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object( - client._inbox_manager, "get_input_message_inbox", return_value=inbox_mock - ) - - client._receive_type = RECEIVE_TYPE_API - client.receive_message_on_input(input_name="some_input", block=block, timeout=timeout) - assert client._receive_type is RECEIVE_TYPE_API - - @pytest.mark.it( - "Raises a ClientError and does nothing else if the client receive mode has been set to Handler Receive Mode" - ) - @pytest.mark.parametrize( - "block,timeout", - [ - pytest.param(True, None, id="Blocking, no timeout"), - pytest.param(True, 10, id="Blocking with timeout"), - pytest.param(False, None, id="Nonblocking"), - ], - ) - def test_receive_mode_set_handler(self, mocker, client, mqtt_pipeline, block, timeout): - inbox_mock = mocker.MagicMock(autospec=SyncClientInbox) - mocker.patch.object( - client._inbox_manager, "get_input_message_inbox", return_value=inbox_mock - ) - # patch this so we can make sure feature enabled isn't modified - mqtt_pipeline.feature_enabled.__getitem__.return_value = False - - client._receive_type = RECEIVE_TYPE_HANDLER - # Error was raised - with pytest.raises(client_exceptions.ClientError): - client.receive_message_on_input(input_name="some_input", block=block, timeout=timeout) - # Feature was not enabled - assert mqtt_pipeline.enable_feature.call_count == 0 - # Inbox get was not called - assert inbox_mock.get.call_count == 0 - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - .receive_method_request()") -class TestIoTHubModuleClientReceiveMethodRequest( - IoTHubModuleClientTestsConfig, SharedClientReceiveMethodRequestTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - .send_method_response()") -class TestIoTHubModuleClientSendMethodResponse( - IoTHubModuleClientTestsConfig, SharedClientSendMethodResponseTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - .get_twin()") -class 
TestIoTHubModuleClientGetTwin(IoTHubModuleClientTestsConfig, SharedClientGetTwinTests): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - .patch_twin_reported_properties()") -class TestIoTHubModuleClientPatchTwinReportedProperties( - IoTHubModuleClientTestsConfig, SharedClientPatchTwinReportedPropertiesTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - .receive_twin_desired_properties_patch()") -class TestIoTHubModuleClientReceiveTwinDesiredPropertiesPatch( - IoTHubModuleClientTestsConfig, SharedClientReceiveTwinDesiredPropertiesPatchTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - .invoke_method()") -class TestIoTHubModuleClientInvokeMethod(WaitsForEventCompletion, IoTHubModuleClientTestsConfig): - @pytest.mark.it("Begins a 'invoke_method' HTTPPipeline operation where the target is a device") - def test_calls_pipeline_invoke_method_for_device(self, client, http_pipeline): - method_params = {"methodName": "__fake_method_name__"} - device_id = "__fake_device_id__" - client.invoke_method(method_params, device_id) - assert http_pipeline.invoke_method.call_count == 1 - assert http_pipeline.invoke_method.call_args[0][0] is device_id - assert http_pipeline.invoke_method.call_args[0][1] is method_params - - @pytest.mark.it("Begins a 'invoke_method' HTTPPipeline operation where the target is a module") - def test_calls_pipeline_invoke_method_for_module(self, client, http_pipeline): - method_params = {"methodName": "__fake_method_name__"} - device_id = "__fake_device_id__" - module_id = "__fake_module_id__" - client.invoke_method(method_params, device_id, module_id=module_id) - assert http_pipeline.invoke_method.call_count == 1 - assert http_pipeline.invoke_method.call_args[0][0] is device_id - assert http_pipeline.invoke_method.call_args[0][1] is method_params - assert http_pipeline.invoke_method.call_args[1]["module_id"] is module_id - - @pytest.mark.it( - "Waits for the completion of the 
'invoke_method' pipeline operation before returning" - ) - def test_waits_for_pipeline_op_completion( - self, mocker, client_manual_cb, http_pipeline_manual_cb - ): - method_params = {"methodName": "__fake_method_name__"} - device_id = "__fake_device_id__" - module_id = "__fake_module_id__" - self.add_event_completion_checks( - mocker=mocker, - pipeline_function=http_pipeline_manual_cb.invoke_method, - kwargs={"invoke_method_response": "__fake_invoke_method_response__"}, - ) - - client_manual_cb.invoke_method(method_params, device_id, module_id=module_id) - - @pytest.mark.it( - "Raises a client error if the `invoke_method` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationCancelled, - client_exceptions.OperationCancelled, - id="OperationCancelled -> OperationCancelled", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - def test_raises_error_on_pipeline_op_error( - self, mocker, client_manual_cb, http_pipeline_manual_cb, pipeline_error, client_error - ): - method_params = {"methodName": "__fake_method_name__"} - device_id = "__fake_device_id__" - module_id = "__fake_module_id__" - my_pipeline_error = pipeline_error() - self.add_event_completion_checks( - mocker=mocker, - pipeline_function=http_pipeline_manual_cb.invoke_method, - kwargs={"error": my_pipeline_error}, - ) - with pytest.raises(client_error) as e_info: - client_manual_cb.invoke_method(method_params, device_id, module_id=module_id) - assert e_info.value.__cause__ is my_pipeline_error - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - PROPERTY .on_message_received") -class TestIoTHubModuleClientPROPERTYOnMessageReceivedHandler( - IoTHubModuleClientTestsConfig, 
SharedIoTHubClientPROPERTYReceiverHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_message_received" - - @pytest.fixture - def feature_name(self): - return pipeline_constant.INPUT_MSG - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - PROPERTY .on_method_request_received") -class TestIoTHubModuleClientPROPERTYOnMethodRequestReceivedHandler( - IoTHubModuleClientTestsConfig, SharedIoTHubClientPROPERTYReceiverHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_method_request_received" - - @pytest.fixture - def feature_name(self): - return pipeline_constant.METHODS - - -@pytest.mark.describe( - "IoTHubModuleClient (Synchronous) - PROPERTY .on_twin_desired_properties_patch_received" -) -class TestIoTHubModuleClientPROPERTYOnTwinDesiredPropertiesPatchReceivedHandler( - IoTHubModuleClientTestsConfig, SharedIoTHubClientPROPERTYReceiverHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_twin_desired_properties_patch_received" - - @pytest.fixture - def feature_name(self): - return pipeline_constant.TWIN_PATCHES - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - PROPERTY .on_connection_state_change") -class TestIoTHubModuleClientPROPERTYOnConnectionStateChangeHandler( - IoTHubModuleClientTestsConfig, SharedIoTHubClientPROPERTYHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_connection_state_change" - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - PROPERTY .on_new_sastoken_required") -class TestIoTHubModuleClientPROPERTYOnNewSastokenRequiredHandler( - IoTHubModuleClientTestsConfig, SharedIoTHubClientPROPERTYHandlerTests -): - @pytest.fixture - def handler_name(self): - return "on_new_sastoken_required" - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - PROPERTY .on_background_exception") -class TestIoTHubModuleClientPROPERTYOnBackgroundExceptionHandler( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientPROPERTYHandlerTests -): - 
@pytest.fixture - def handler_name(self): - return "on_background_exception" - - -@pytest.mark.describe("IoTHubModule (Synchronous) - PROPERTY .connected") -class TestIoTHubModuleClientPROPERTYConnected( - IoTHubModuleClientTestsConfig, SharedIoTHubClientPROPERTYConnectedTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - OCCURRENCE: Connect") -class TestIoTHubModuleClientOCCURRENCEConnect( - IoTHubModuleClientTestsConfig, SharedIoTHubClientOCCURRENCEConnectTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - OCCURRENCE: Disconnect") -class TestIoTHubModuleClientOCCURRENCEDisconnect( - IoTHubModuleClientTestsConfig, SharedIoTHubClientOCCURRENCEDisconnectTests -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - OCCURRENCE: New Sastoken Required") -class TestIoTHubModuleClientOCCURRENCENewSastokenRequired( - IoTHubModuleClientTestsConfig, SharedIoTHubClientOCCURRENCENewSastokenRequired -): - pass - - -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - OCCURRENCE: Background Exception") -class TestIoTHubModuleClientOCCURRENCEBackgroundException( - IoTHubDeviceClientTestsConfig, SharedIoTHubClientOCCURRENCEBackgroundException -): - pass diff --git a/tests/unit/iothub/test_sync_handler_manager.py b/tests/unit/iothub/test_sync_handler_manager.py deleted file mode 100644 index 3bdf24920..000000000 --- a/tests/unit/iothub/test_sync_handler_manager.py +++ /dev/null @@ -1,724 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import logging -import pytest -import threading -import time -from azure.iot.device.common import handle_exceptions -from azure.iot.device.iothub import client_event -from azure.iot.device.iothub.sync_handler_manager import SyncHandlerManager, HandlerManagerException -from azure.iot.device.iothub.sync_handler_manager import MESSAGE, METHOD, TWIN_DP_PATCH -from azure.iot.device.iothub.inbox_manager import InboxManager -from azure.iot.device.iothub.sync_inbox import SyncClientInbox - -logging.basicConfig(level=logging.DEBUG) - -# NOTE ON TEST IMPLEMENTATION: -# Despite having significant shared implementation between the sync and async handler managers, -# there are not shared tests. This is because while both have the same set of requirements and -# APIs, the internal implementation is different to an extent that it simply isn't really possible -# to test them to an appropriate degree of correctness with a shared set of tests. -# This means we must be very careful to always change both test modules when a change is made to -# shared behavior, or when shared features are added. - -# NOTE ON TIMING/DELAY -# Several tests in this module have sleeps/delays in their implementation due to needing to wait -# for things to happen in other threads. 
- -all_internal_receiver_handlers = [MESSAGE, METHOD, TWIN_DP_PATCH] -all_internal_client_event_handlers = [ - "_on_connection_state_change", - "_on_new_sastoken_required", - "_on_background_exception", -] -all_internal_handlers = all_internal_receiver_handlers + all_internal_client_event_handlers -all_receiver_handlers = [s.lstrip("_") for s in all_internal_receiver_handlers] -all_client_event_handlers = [s.lstrip("_") for s in all_internal_client_event_handlers] -all_handlers = all_receiver_handlers + all_client_event_handlers - - -class ThreadsafeMock(object): - """This class provides (some) Mock functionality in a threadsafe manner, specifically, it - ensures that the 'call_count' attribute will be accurate when the mock is called from another - thread. - - It does not cover ALL mock functionality, but more features could be added to it as necessary - """ - - def __init__(self): - self.call_count = 0 - self.lock = threading.Lock() - - def __call__(self, *args, **kwargs): - with self.lock: - self.call_count += 1 - - -@pytest.fixture -def inbox_manager(mocker): - return InboxManager(inbox_type=SyncClientInbox) - - -@pytest.fixture -def handler(): - def some_handler_fn(arg): - pass - - return some_handler_fn - - -@pytest.mark.describe("SyncHandlerManager - Instantiation") -class TestInstantiation(object): - @pytest.mark.it("Initializes handler properties to None") - @pytest.mark.parametrize("handler_name", all_handlers) - def test_handlers(self, inbox_manager, handler_name): - hm = SyncHandlerManager(inbox_manager) - assert getattr(hm, handler_name) is None - - @pytest.mark.it("Initializes receiver handler runner thread references to None") - @pytest.mark.parametrize( - "handler_name", all_internal_receiver_handlers, ids=all_receiver_handlers - ) - def test_receiver_handler_runners(self, inbox_manager, handler_name): - hm = SyncHandlerManager(inbox_manager) - assert hm._receiver_handler_runners[handler_name] is None - - @pytest.mark.it("Initializes client event 
handler runner thread reference to None") - def test_client_event_handler_runner(self, inbox_manager): - hm = SyncHandlerManager(inbox_manager) - assert hm._client_event_runner is None - - -@pytest.mark.describe("SyncHandlerManager - .stop()") -class TestStop(object): - @pytest.fixture( - params=[ - "No handlers running", - "Some receiver handlers running", - "Some client event handlers running", - "Some receiver and some client event handlers running", - "All handlers running", - ] - ) - def handler_manager(self, request, inbox_manager, handler): - hm = SyncHandlerManager(inbox_manager) - if request.param == "Some receiver handlers running": - # Set an arbitrary receiver handler - hm.on_message_received = handler - elif request.param == "Some client event handlers running": - # Set an arbitrary client event handler - hm.on_connection_state_change = handler - elif request.param == "Some receiver and some client event handlers running": - # Set an arbitrary receiver and client event handler - hm.on_message_received = handler - hm.on_connection_state_change = handler - elif request.param == "All handlers running": - # NOTE: this sets all handlers to be the same fn, but this doesn't really - # make a difference in this context - for handler_name in all_handlers: - setattr(hm, handler_name, handler) - yield hm - hm.stop() - - @pytest.mark.it("Stops all currently running handlers") - def test_stop_all(self, handler_manager): - handler_manager.stop() - for handler_name in all_internal_receiver_handlers: - assert handler_manager._receiver_handler_runners[handler_name] is None - assert handler_manager._client_event_runner is None - - @pytest.mark.it( - "Stops only the currently running receiver handlers if the 'receiver_handlers_only' parameter is True" - ) - def test_stop_only_receiver_handlers(self, handler_manager): - if handler_manager._client_event_runner is not None: - client_event_handlers_running = True - else: - client_event_handlers_running = False - - 
handler_manager.stop(receiver_handlers_only=True) - - # All receiver handlers have stopped - for handler_name in all_internal_receiver_handlers: - assert handler_manager._receiver_handler_runners[handler_name] is None - # If the client event handlers were running, they are STILL running - if client_event_handlers_running: - assert handler_manager._client_event_runner is not None - - @pytest.mark.it("Completes all pending handler invocations before stopping the runner(s)") - def test_completes_pending(self, mocker, inbox_manager): - hm = SyncHandlerManager(inbox_manager) - - # NOTE: We use two handlers arbitrarily here to show this happens for all handler runners - mock_msg_handler = ThreadsafeMock() - mock_mth_handler = ThreadsafeMock() - msg_inbox = inbox_manager.get_unified_message_inbox() - mth_inbox = inbox_manager.get_method_request_inbox() - for _ in range(200): # sufficiently many items so can't complete quickly - msg_inbox.put(mocker.MagicMock()) - mth_inbox.put(mocker.MagicMock()) - - hm.on_message_received = mock_msg_handler - hm.on_method_request_received = mock_mth_handler - assert mock_msg_handler.call_count < 200 - assert mock_mth_handler.call_count < 200 - hm.stop() - time.sleep(0.1) - assert mock_msg_handler.call_count == 200 - assert mock_mth_handler.call_count == 200 - assert msg_inbox.empty() - assert mth_inbox.empty() - - -@pytest.mark.describe("SyncHandlerManager - .ensure_running()") -class TestEnsureRunning(object): - @pytest.fixture( - params=[ - "All handlers set, all stopped", - "All handlers set, receivers stopped, client events running", - "All handlers set, all running", - "Some receiver and client event handlers set, all stopped", - "Some receiver and client event handlers set, receivers stopped, client events running", - "Some receiver and client event handlers set, all running", - "Some receiver handlers set, all stopped", - "Some receiver handlers set, all running", - "Some client event handlers set, all stopped", - "Some client 
event handlers set, all running", - "No handlers set", - ] - ) - def handler_manager(self, request, inbox_manager, handler): - # NOTE: this sets all handlers to be the same fn, but this doesn't really - # make a difference in this context - hm = SyncHandlerManager(inbox_manager) - - if request.param == "All handlers set, all stopped": - for handler_name in all_handlers: - setattr(hm, handler_name, handler) - hm.stop() - elif request.param == "All handlers set, receivers stopped, client events running": - for handler_name in all_handlers: - setattr(hm, handler_name, handler) - hm.stop(receiver_handlers_only=True) - elif request.param == "All handlers set, all running": - for handler_name in all_handlers: - setattr(hm, handler_name, handler) - elif request.param == "Some receiver and client event handlers set, all stopped": - hm.on_message_received = handler - hm.on_method_request_received = handler - hm.on_connection_state_change = handler - hm.on_new_sastoken_required = handler - hm.stop() - elif ( - request.param - == "Some receiver and client event handlers set, receivers stopped, client events running" - ): - hm.on_message_received = handler - hm.on_method_request_received = handler - hm.on_connection_state_change = handler - hm.on_new_sastoken_required = handler - hm.stop(receiver_handlers_only=True) - elif request.param == "Some receiver and client event handlers set, all running": - hm.on_message_received = handler - hm.on_method_request_received = handler - hm.on_connection_state_change = handler - hm.on_new_sastoken_required = handler - elif request.param == "Some receiver handlers set, all stopped": - hm.on_message_received = handler - hm.on_method_request_received = handler - hm.stop() - elif request.param == "Some receiver handlers set, all running": - hm.on_message_received = handler - hm.on_method_request_received = handler - elif request.param == "Some client event handlers set, all stopped": - hm.on_connection_state_change = handler - 
hm.on_new_sastoken_required = handler - hm.stop() - elif request.param == "Some client event handlers set, all running": - hm.on_connection_state_change = handler - hm.on_new_sastoken_required = handler - - yield hm - hm.stop() - - @pytest.mark.it( - "Starts handler runners for any handler that is set, but does not have a handler runner running" - ) - def test_starts_runners_if_necessary(self, handler_manager): - handler_manager.ensure_running() - - # Check receiver handlers - for handler_name in all_receiver_handlers: - if getattr(handler_manager, handler_name) is not None: - # NOTE: this assumes the convention of internal names being the name of a handler - # prefixed with a "_". If this ever changes, you must change this test. - assert handler_manager._receiver_handler_runners["_" + handler_name] is not None - - # Check client event handlers - for handler_name in all_client_event_handlers: - if getattr(handler_manager, handler_name) is not None: - assert handler_manager._client_event_runner is not None - # don't need to check the rest of the handlers since they all share a runner - break - - -# ############## -# # PROPERTIES # -# ############## - - -class SharedHandlerPropertyTests(object): - @pytest.fixture - def handler_manager(self, inbox_manager): - hm = SyncHandlerManager(inbox_manager) - yield hm - hm.stop() - - # NOTE: We use setattr() and getattr() in these tests so they're generic to all properties. - # This is functionally identical to doing explicit assignment to a property, it just - # doesn't read quite as well. 
- - @pytest.mark.it("Can be both read and written to") - def test_read_write(self, handler_name, handler_manager, handler): - assert getattr(handler_manager, handler_name) is None - setattr(handler_manager, handler_name, handler) - assert getattr(handler_manager, handler_name) is handler - setattr(handler_manager, handler_name, None) - assert getattr(handler_manager, handler_name) is None - - -class SharedReceiverHandlerPropertyTests(SharedHandlerPropertyTests): - # NOTE: If there is ever any deviation in the convention of what the internal names of handlers - # are other than just a prefixed "_", we'll have to move this fixture to the child classes so - # it can be unique to each handler - @pytest.fixture - def handler_name_internal(self, handler_name): - return "_" + handler_name - - @pytest.mark.it( - "Creates and starts a daemon Thread for the corresponding handler runner when value is set to a function" - ) - def test_thread_created(self, handler_name, handler_name_internal, handler_manager, handler): - assert handler_manager._receiver_handler_runners[handler_name_internal] is None - setattr(handler_manager, handler_name, handler) - assert isinstance( - handler_manager._receiver_handler_runners[handler_name_internal], threading.Thread - ) - assert handler_manager._receiver_handler_runners[handler_name_internal].daemon is True - - @pytest.mark.it( - "Stops the corresponding handler runner and completes any existing daemon Thread for it when the value is set back to None" - ) - def test_thread_removed(self, handler_name, handler_name_internal, handler_manager, handler): - # Set handler - setattr(handler_manager, handler_name, handler) - # Thread has been created and is alive - t = handler_manager._receiver_handler_runners[handler_name_internal] - assert isinstance(t, threading.Thread) - assert t.is_alive() - # Set the handler back to None - setattr(handler_manager, handler_name, None) - # Thread has finished and the manager no longer has a reference to it - 
assert not t.is_alive() - assert handler_manager._receiver_handler_runners[handler_name_internal] is None - - @pytest.mark.it( - "Does not delete, remove, or replace the Thread for the corresponding handler runner, when updated with a new function value" - ) - def test_thread_unchanged_by_handler_update( - self, handler_name, handler_name_internal, handler_manager, handler - ): - # Set the handler - setattr(handler_manager, handler_name, handler) - # Thread has been crated and is alive - t = handler_manager._receiver_handler_runners[handler_name_internal] - assert isinstance(t, threading.Thread) - assert t.is_alive() - - # Set new handler - def new_handler(arg): - pass - - setattr(handler_manager, handler_name, new_handler) - assert handler_manager._receiver_handler_runners[handler_name_internal] is t - assert t.is_alive() - - @pytest.mark.it( - "Is invoked by the runner when the Inbox corresponding to the handler receives an object, passing that object to the handler" - ) - def test_handler_invoked(self, mocker, handler_name, handler_manager, inbox): - # Set the handler - mock_handler = mocker.MagicMock() - setattr(handler_manager, handler_name, mock_handler) - # Handler has not been called - assert mock_handler.call_count == 0 - - # Add an item to corresponding inbox, triggering the handler - mock_obj = mocker.MagicMock() - inbox.put(mock_obj) - time.sleep(0.1) - - # Handler has been called with the item from the inbox - assert mock_handler.call_count == 1 - assert mock_handler.call_args == mocker.call(mock_obj) - - @pytest.mark.it( - "Is invoked by the runner every time the Inbox corresponding to the handler receives an object" - ) - def test_handler_invoked_multiple(self, mocker, handler_name, handler_manager, inbox): - # Set the handler - mock_handler = ThreadsafeMock() - setattr(handler_manager, handler_name, mock_handler) - # Handler has not been called - assert mock_handler.call_count == 0 - - # Add 5 items to the corresponding inbox, triggering the handler 
- for _ in range(5): - inbox.put(mocker.MagicMock()) - time.sleep(0.2) - - # Handler has been called 5 times - assert mock_handler.call_count == 5 - - @pytest.mark.it( - "Is invoked for every item already in the corresponding Inbox at the moment of handler removal" - ) - def test_handler_resolve_pending_items_before_handler_removal( - self, mocker, handler_name, handler_manager, inbox - ): - # Use a threadsafe mock to ensure accurate counts - mock_handler = ThreadsafeMock() - assert inbox.empty() - # Queue up a bunch of items in the inbox - for _ in range(100): - inbox.put(mocker.MagicMock()) - # The handler has not yet been called - assert mock_handler.call_count == 0 - # Items are still in the inbox - assert not inbox.empty() - # Set the handler - setattr(handler_manager, handler_name, mock_handler) - # The handler has not yet been called for everything that was in the inbox - # NOTE: I'd really like to show that the handler call count is also > 0 here, but - # it's pretty difficult to make the timing work - assert mock_handler.call_count < 100 - - # Immediately remove the handler - setattr(handler_manager, handler_name, None) - # Wait to give a chance for the handler runner to finish calling everything - time.sleep(0.2) - # Despite removal, handler has been called for everything that was in the inbox at the - # time of the removal - assert mock_handler.call_count == 100 - assert inbox.empty() - - # Add some more items - for _ in range(100): - inbox.put(mocker.MagicMock()) - # Wait to give a chance for the handler to be called (it won't) - time.sleep(0.2) - # Despite more items added to inbox, no further handler calls have been made beyond the - # initial calls that were made when the original items were added - assert mock_handler.call_count == 100 - - @pytest.mark.it( - "Sends a HandlerManagerException to the background exception handler if any exception is raised during its invocation" - ) - def test_exception_in_handler( - self, mocker, handler_name, 
handler_manager, inbox, arbitrary_exception - ): - background_exc_spy = mocker.spy(handle_exceptions, "handle_background_exception") - # Handler will raise exception when called - mock_handler = mocker.MagicMock() - mock_handler.side_effect = arbitrary_exception - # Set handler - setattr(handler_manager, handler_name, mock_handler) - # Handler has not been called - assert mock_handler.call_count == 0 - # Background exception handler has not been called - assert background_exc_spy.call_count == 0 - # Add an item to corresponding inbox, triggering the handler - inbox.put(mocker.MagicMock()) - time.sleep(0.1) - # Handler has now been called - assert mock_handler.call_count == 1 - # Background exception handler was called - assert background_exc_spy.call_count == 1 - e = background_exc_spy.call_args[0][0] - assert isinstance(e, HandlerManagerException) - assert e.__cause__ is arbitrary_exception - - @pytest.mark.it( - "Can be updated with a new value that the corresponding handler runner will immediately begin using for handler invocations instead" - ) - def test_handler_update_handler(self, mocker, handler_name, handler_manager, inbox): - def handler(arg): - # Invoking handler replaces the set handler with a mock - setattr(handler_manager, handler_name, mocker.MagicMock()) - - setattr(handler_manager, handler_name, handler) - - inbox.put(mocker.MagicMock()) - time.sleep(0.1) - # Handler has been replaced with a mock, but the mock has not been invoked - assert getattr(handler_manager, handler_name) is not handler - assert getattr(handler_manager, handler_name).call_count == 0 - # Add a new item to the inbox - inbox.put(mocker.MagicMock()) - time.sleep(0.1) - # The mock was now called - assert getattr(handler_manager, handler_name).call_count == 1 - - -class SharedClientEventHandlerPropertyTests(SharedHandlerPropertyTests): - @pytest.fixture - def inbox(self, inbox_manager): - return inbox_manager.get_client_event_inbox() - - @pytest.mark.it( - "Creates and starts a 
daemon Thread for the Client Event handler runner when value is set to a function if the Client Event handler runner does not already exist" - ) - def test_no_client_event_runner(self, handler_name, handler_manager, handler): - assert handler_manager._client_event_runner is None - setattr(handler_manager, handler_name, handler) - t = handler_manager._client_event_runner - assert isinstance(t, threading.Thread) - assert t.daemon is True - - @pytest.mark.it( - "Does not modify the Client Event handler runner thread when value is set to a function if the Client Event handler runner already exists" - ) - def test_client_event_runner_already_exists(self, handler_name, handler_manager, handler): - # Add a fake client event runner thread - fake_runner_thread = threading.Thread() - fake_runner_thread.daemon = True - fake_runner_thread.start() - handler_manager._client_event_runner = fake_runner_thread - # Set handler - setattr(handler_manager, handler_name, handler) - # Fake thread was not changed - assert handler_manager._client_event_runner is fake_runner_thread - - @pytest.mark.it( - "Does not delete, remove, or replace the Thread for the Client Event handler runner when value is set back to None" - ) - def test_handler_removed(self, handler_name, handler_manager, handler): - # Set handler - setattr(handler_manager, handler_name, handler) - # Thread has been created and is alive - t = handler_manager._client_event_runner - assert isinstance(t, threading.Thread) - assert t.is_alive() - # Set the handler back to None - setattr(handler_manager, handler_name, None) - # Thread is still maintained on the manager and alive - assert handler_manager._client_event_runner is t - assert t.is_alive() - - @pytest.mark.it( - "Does not delete, remove, or replace the Thread for the Client Event handler runner when updated with a new function value" - ) - def test_handler_update(self, handler_name, handler_manager, handler): - # Set handler - setattr(handler_manager, handler_name, 
handler) - # Thread has been created and is alive - t = handler_manager._client_event_runner - assert isinstance(t, threading.Thread) - assert t.is_alive() - - # Set new handler - def new_handler(arg): - pass - - setattr(handler_manager, handler_name, new_handler) - - # Thread is still maintained on the manager and alive - assert handler_manager._client_event_runner is t - assert t.is_alive() - - @pytest.mark.it( - "Is invoked by the runner only when the Client Event Inbox receives a matching Client Event, passing any arguments to the handler" - ) - def test_handler_invoked(self, mocker, handler_name, handler_manager, inbox, event): - # Set the handler - mock_handler = mocker.MagicMock() - setattr(handler_manager, handler_name, mock_handler) - # Handler has not been called - assert mock_handler.call_count == 0 - - # Add the event to the client event inbox - inbox.put(event) - time.sleep(0.1) - - # Handler has been called with the arguments from the event - assert mock_handler.call_count == 1 - assert mock_handler.call_args == mocker.call(*event.args_for_user) - - # Add non-matching event to the client event inbox - non_matching_event = client_event.ClientEvent("NON_MATCHING_EVENT") - inbox.put(non_matching_event) - time.sleep(0.1) - - # Handler has not been called again - assert mock_handler.call_count == 1 - - @pytest.mark.it( - "Is invoked by the runner every time the Client Event Inbox receives a matching Client Event" - ) - def test_handler_invoked_multiple(self, handler_name, handler_manager, inbox, event): - # Set the handler - mock_handler = ThreadsafeMock() - setattr(handler_manager, handler_name, mock_handler) - # Handler has not been called - assert mock_handler.call_count == 0 - - # Add 5 matching events to the corresponding inbox, triggering the handler - for _ in range(5): - inbox.put(event) - time.sleep(0.2) - - # Handler has been called 5 times - assert mock_handler.call_count == 5 - - @pytest.mark.it( - "Sends a HandlerManagerException to the 
background exception handler if any exception is raised during its invocation" - ) - def test_exception_in_handler( - self, mocker, handler_name, handler_manager, inbox, event, arbitrary_exception - ): - background_exc_spy = mocker.spy(handle_exceptions, "handle_background_exception") - # Handler will raise exception when called - mock_handler = mocker.MagicMock() - mock_handler.side_effect = arbitrary_exception - # Set handler - setattr(handler_manager, handler_name, mock_handler) - # Handler has not been called - assert mock_handler.call_count == 0 - # Background exception handler has not been called - assert background_exc_spy.call_count == 0 - # Add the event to the client event inbox, triggering the handler - inbox.put(event) - time.sleep(0.1) - # Handler has now been called - assert mock_handler.call_count == 1 - # Background exception handler was called - assert background_exc_spy.call_count == 1 - e = background_exc_spy.call_args[0][0] - assert isinstance(e, HandlerManagerException) - assert e.__cause__ is arbitrary_exception - - @pytest.mark.it( - "Can be updated with a new value that the Client Event handler runner will immediately begin using for handler invocations instead" - ) - def test_updated_handler(self, mocker, handler_name, handler_manager, inbox, event): - def handler(*args): - # Invoking handler replaces the set handler with a mock - setattr(handler_manager, handler_name, mocker.MagicMock()) - - setattr(handler_manager, handler_name, handler) - - inbox.put(event) - time.sleep(0.1) - # Handler has been replaced with a mock, but the mock has not been invoked - assert getattr(handler_manager, handler_name) is not handler - assert getattr(handler_manager, handler_name).call_count == 0 - # Add a new event to the inbox - inbox.put(event) - time.sleep(0.1) - # The mock was now called - assert getattr(handler_manager, handler_name).call_count == 1 - - -@pytest.mark.describe("SyncHandlerManager - PROPERTY: .on_message_received") -class 
TestSyncHandlerManagerPropertyOnMessageReceived(SharedReceiverHandlerPropertyTests): - @pytest.fixture - def handler_name(self): - return "on_message_received" - - @pytest.fixture - def inbox(self, inbox_manager): - return inbox_manager.get_unified_message_inbox() - - -@pytest.mark.describe("SyncHandlerManager - PROPERTY: .on_method_request_received") -class TestSyncHandlerManagerPropertyOnMethodRequestReceived(SharedReceiverHandlerPropertyTests): - @pytest.fixture - def handler_name(self): - return "on_method_request_received" - - @pytest.fixture - def inbox(self, inbox_manager): - return inbox_manager.get_method_request_inbox() - - -@pytest.mark.describe("SyncHandlerManager - PROPERTY: .on_twin_desired_properties_patch_received") -class TestSyncHandlerManagerPropertyOnTwinDesiredPropertiesPatchReceived( - SharedReceiverHandlerPropertyTests -): - @pytest.fixture - def handler_name(self): - return "on_twin_desired_properties_patch_received" - - @pytest.fixture - def inbox(self, inbox_manager): - return inbox_manager.get_twin_patch_inbox() - - -@pytest.mark.describe("SyncHandlerManager - PROPERTY: .on_connection_state_change") -class TestSyncHandlerManagerPropertyOnConnectionStateChange(SharedClientEventHandlerPropertyTests): - @pytest.fixture - def handler_name(self): - return "on_connection_state_change" - - @pytest.fixture - def event(self): - return client_event.ClientEvent(client_event.CONNECTION_STATE_CHANGE) - - -@pytest.mark.describe("SyncHandlerManager - PROPERTY: .on_new_sastoken_required") -class TestSyncHandlerManagerPropertyOnNewSastokenRequired(SharedClientEventHandlerPropertyTests): - @pytest.fixture - def handler_name(self): - return "on_new_sastoken_required" - - @pytest.fixture - def event(self): - return client_event.ClientEvent(client_event.NEW_SASTOKEN_REQUIRED) - - -@pytest.mark.describe("SyncHandlerManager - PROPERTY: .on_background_exception") -class TestSyncHandlerManagerPropertyOnBackgroundException(SharedClientEventHandlerPropertyTests): - 
@pytest.fixture - def handler_name(self): - return "on_background_exception" - - @pytest.fixture - def event(self, arbitrary_exception): - return client_event.ClientEvent(client_event.BACKGROUND_EXCEPTION, arbitrary_exception) - - -@pytest.mark.describe("SyncHandlerManager - PROPERTY: .handling_client_events") -class TestSyncHandlerManagerPropertyHandlingClientEvents(object): - @pytest.fixture - def handler_manager(self, inbox_manager): - hm = SyncHandlerManager(inbox_manager) - yield hm - hm.stop() - - @pytest.mark.it("Is True if the Client Event Handler Runner is running") - def test_client_event_runner_running(self, handler_manager): - # Add a fake client event runner thread - fake_runner_thread = threading.Thread() - fake_runner_thread.daemon = True - fake_runner_thread.start() - handler_manager._client_event_runner = fake_runner_thread - - assert handler_manager.handling_client_events is True - - @pytest.mark.it("Is False if the Client Event Handler Runner is not running") - def test_client_event_runner_not_running(self, handler_manager): - assert handler_manager._client_event_runner is None - assert handler_manager.handling_client_events is False diff --git a/tests/unit/iothub/test_sync_inbox.py b/tests/unit/iothub/test_sync_inbox.py deleted file mode 100644 index cfecdfc9f..000000000 --- a/tests/unit/iothub/test_sync_inbox.py +++ /dev/null @@ -1,134 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import pytest -import logging -import threading -import time -from azure.iot.device.iothub.sync_inbox import SyncClientInbox, InboxEmpty - -logging.basicConfig(level=logging.DEBUG) - - -@pytest.mark.describe("SyncClientInbox") -class TestSyncClientInbox(object): - @pytest.mark.it("Instantiates empty") - def test_instantiates_empty(self): - inbox = SyncClientInbox() - assert inbox.empty() - - @pytest.mark.it("Can be checked regarding whether or not it contains an item") - def test_check_item_is_in_inbox(self, mocker): - inbox = SyncClientInbox() - assert inbox.empty() - item = mocker.MagicMock() - assert item not in inbox - inbox.put(item) - assert item in inbox - - @pytest.mark.it("Can checked regarding whether or not it is empty") - def test_check_if_empty(self, mocker): - inbox = SyncClientInbox() - assert inbox.empty() - item = mocker.MagicMock() - inbox.put(item) - assert not inbox.empty() - inbox.get() - assert inbox.empty() - - @pytest.mark.it("Operates according to FIFO") - def test_operates_according_to_FIFO(self, mocker): - inbox = SyncClientInbox() - item1 = mocker.MagicMock() - item2 = mocker.MagicMock() - item3 = mocker.MagicMock() - inbox.put(item1) - inbox.put(item2) - inbox.put(item3) - - assert inbox.get() is item1 - assert inbox.get() is item2 - assert inbox.get() is item3 - - -@pytest.mark.describe("SyncClientInbox - .put()") -class TestSyncClientInboxPut(object): - @pytest.mark.it("Adds the given item to the inbox") - def test_adds_item_to_inbox(self, mocker): - inbox = SyncClientInbox() - assert inbox.empty() - item = mocker.MagicMock() - inbox.put(item) - assert not inbox.empty() - assert item in inbox - - -@pytest.mark.describe("SyncClientInbox - .get()") -class TestSyncClientInboxGet(object): - @pytest.mark.it("Returns and removes the next item from the inbox, if there is one") - def test_removes_item_from_inbox_if_already_there(self, mocker): - inbox = 
SyncClientInbox() - assert inbox.empty() - item = mocker.MagicMock() - inbox.put(item) - assert not inbox.empty() - retrieved_item = inbox.get() - assert retrieved_item is item - assert inbox.empty() - - @pytest.mark.it( - "Blocks on an empty inbox until an item is available to remove and return, if using blocking mode" - ) - def test_waits_for_item_to_be_added_if_inbox_empty_in_blocking_mode(self, mocker): - inbox = SyncClientInbox() - assert inbox.empty() - item = mocker.MagicMock() - - def insert_item(): - time.sleep(0.01) # wait before inserting - inbox.put(item) - - insertion_thread = threading.Thread(target=insert_item) - insertion_thread.start() - - retrieved_item = inbox.get(block=True) - assert retrieved_item is item - assert inbox.empty() - - @pytest.mark.it( - "Raises InboxEmpty exception after a timeout while blocking on an empty inbox, if a timeout is specified" - ) - def test_times_out_while_blocking_if_timeout_specified(self, mocker): - inbox = SyncClientInbox() - assert inbox.empty() - with pytest.raises(InboxEmpty): - inbox.get(block=True, timeout=0.01) - - @pytest.mark.it( - "Raises InboxEmpty exception if the inbox is empty, when using non-blocking mode" - ) - def test_get_raises_empty_if_inbox_empty_in_non_blocking_mode(self): - inbox = SyncClientInbox() - assert inbox.empty() - with pytest.raises(InboxEmpty): - inbox.get(block=False) - - -@pytest.mark.describe("SyncClientInbox - .clear()") -class TestSyncClientInboxClear(object): - @pytest.mark.it("Clears all items from the inbox") - def test_can_clear_all_items(self, mocker): - inbox = SyncClientInbox() - item1 = mocker.MagicMock() - item2 = mocker.MagicMock() - item3 = mocker.MagicMock() - inbox.put(item1) - inbox.put(item2) - inbox.put(item3) - assert not inbox.empty() - - inbox.clear() - assert inbox.empty() diff --git a/tests/unit/provisioning/__init__.py b/tests/unit/provisioning/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git 
a/tests/unit/provisioning/aio/__init__.py b/tests/unit/provisioning/aio/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/provisioning/aio/test_async_provisioning_device_client.py b/tests/unit/provisioning/aio/test_async_provisioning_device_client.py deleted file mode 100644 index 2b1106a52..000000000 --- a/tests/unit/provisioning/aio/test_async_provisioning_device_client.py +++ /dev/null @@ -1,365 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import pytest -import logging -from azure.iot.device.provisioning.aio.async_provisioning_device_client import ( - ProvisioningDeviceClient, -) -from azure.iot.device.common import async_adapter -import asyncio -from azure.iot.device.iothub.pipeline import exceptions as pipeline_exceptions -from azure.iot.device import exceptions as client_exceptions -from ..shared_client_tests import ( - SharedProvisioningClientInstantiationTests, - SharedProvisioningClientCreateFromSymmetricKeyTests, - SharedProvisioningClientCreateFromX509CertificateTests, -) - - -logging.basicConfig(level=logging.DEBUG) -pytestmark = pytest.mark.asyncio - - -async def create_completed_future(result=None): - f = asyncio.Future() - f.set_result(result) - return f - - -class ProvisioningClientTestsConfig(object): - """Defines fixtures for asynchronous ProvisioningDeviceClient tests""" - - @pytest.fixture - def client_class(self): - return ProvisioningDeviceClient - - @pytest.fixture - def client(self, provisioning_pipeline): - return ProvisioningDeviceClient(provisioning_pipeline) - - -@pytest.mark.describe("ProvisioningDeviceClient (Async) - Instantiation") -class TestProvisioningClientInstantiation( - ProvisioningClientTestsConfig, 
SharedProvisioningClientInstantiationTests -): - pass - - -@pytest.mark.describe("ProvisioningDeviceClient (Async) - .create_from_symmetric_key()") -class TestProvisioningClientCreateFromSymmetricKey( - ProvisioningClientTestsConfig, SharedProvisioningClientCreateFromSymmetricKeyTests -): - pass - - -@pytest.mark.describe("ProvisioningDeviceClient (Async) - .create_from_x509_certificate()") -class TestProvisioningClientCreateFromX509Certificate( - ProvisioningClientTestsConfig, SharedProvisioningClientCreateFromX509CertificateTests -): - pass - - -@pytest.mark.describe("ProvisioningDeviceClient (Async) - .register()") -class TestClientRegister(object): - @pytest.mark.it("Implicitly enables responses from provisioning service if not already enabled") - async def test_enables_provisioning_only_if_not_already_enabled( - self, mocker, provisioning_pipeline, registration_result - ): - # Override callback to pass successful result - def register_complete_success_callback(payload, callback): - callback(result=registration_result) - - mocker.patch.object( - provisioning_pipeline, "register", side_effect=register_complete_success_callback - ) - - provisioning_pipeline.responses_enabled.__getitem__.return_value = False - - client = ProvisioningDeviceClient(provisioning_pipeline) - await client.register() - - assert provisioning_pipeline.enable_responses.call_count == 1 - - provisioning_pipeline.enable_responses.reset_mock() - - provisioning_pipeline.responses_enabled.__getitem__.return_value = True - await client.register() - assert provisioning_pipeline.enable_responses.call_count == 0 - - @pytest.mark.it("Begins a 'register' pipeline operation") - async def test_register_calls_pipeline_register( - self, provisioning_pipeline, mocker, registration_result - ): - def register_complete_success_callback(payload, callback): - callback(result=registration_result) - - mocker.patch.object( - provisioning_pipeline, "register", side_effect=register_complete_success_callback - ) - 
client = ProvisioningDeviceClient(provisioning_pipeline) - await client.register() - assert provisioning_pipeline.register.call_count == 1 - - @pytest.mark.it( - "Begins a 'shutdown' pipeline operation if the registration result is successful" - ) - async def test_shutdown_upon_success(self, mocker, provisioning_pipeline, registration_result): - # success result - registration_result._status = "assigned" - - def register_complete_success_callback(payload, callback): - callback(result=registration_result) - - mocker.patch.object( - provisioning_pipeline, "register", side_effect=register_complete_success_callback - ) - - client = ProvisioningDeviceClient(provisioning_pipeline) - await client.register() - - assert provisioning_pipeline.shutdown.call_count == 1 - - @pytest.mark.it( - "Does NOT begin a 'shutdown' pipeline operation if the registration result is NOT successful" - ) - async def test_no_shutdown_upon_fail(self, mocker, provisioning_pipeline, registration_result): - # fail result - registration_result._status = "not assigned" - - def register_complete_fail_callback(payload, callback): - callback(result=registration_result) - - mocker.patch.object( - provisioning_pipeline, "register", side_effect=register_complete_fail_callback - ) - - client = ProvisioningDeviceClient(provisioning_pipeline) - await client.register() - - assert provisioning_pipeline.shutdown.call_count == 0 - - @pytest.mark.it( - "Waits for the completion of both the 'register' and 'shutdown' pipeline operations before returning, if the registration result is successful" - ) - async def test_waits_for_pipeline_op_completions_on_success( - self, mocker, provisioning_pipeline, registration_result - ): - # success result - registration_result._status = "assigned" - - # Set up mocks - cb_mock_register = mocker.MagicMock() - cb_mock_shutdown = mocker.MagicMock() - cb_mock_register.completion.return_value = await create_completed_future( - registration_result - ) - 
cb_mock_shutdown.completion.return_value = await create_completed_future(None) - mocker.patch.object(async_adapter, "AwaitableCallback").side_effect = [ - cb_mock_register, - cb_mock_shutdown, - ] - - # Run test - client = ProvisioningDeviceClient(provisioning_pipeline) - await client.register() - - # Calls made as expected - assert provisioning_pipeline.register.call_count == 1 - assert provisioning_pipeline.shutdown.call_count == 1 - # Callbacks sent to pipeline as expected - assert provisioning_pipeline.register.call_args == mocker.call( - payload=mocker.ANY, callback=cb_mock_register - ) - assert provisioning_pipeline.shutdown.call_args == mocker.call(callback=cb_mock_shutdown) - # Callback completions were waited upon as expected - assert cb_mock_register.completion.call_count == 1 - assert cb_mock_shutdown.completion.call_count == 1 - - @pytest.mark.it( - "Waits for the completion of just the 'register' pipeline operation before returning, if the registration result is NOT successful" - ) - async def test_waits_for_pipeline_op_completion_on_failure( - self, mocker, provisioning_pipeline, registration_result - ): - # fail result - registration_result._status = "not assigned" - - # Set up mocks - cb_mock_register = mocker.MagicMock() - cb_mock_shutdown = mocker.MagicMock() - cb_mock_register.completion.return_value = await create_completed_future( - registration_result - ) - cb_mock_shutdown.completion.return_value = await create_completed_future(None) - mocker.patch.object(async_adapter, "AwaitableCallback").side_effect = [ - cb_mock_register, - cb_mock_shutdown, - ] - - # Run test - client = ProvisioningDeviceClient(provisioning_pipeline) - await client.register() - - # Calls made as expected - assert provisioning_pipeline.register.call_count == 1 - assert provisioning_pipeline.shutdown.call_count == 0 - # Callbacks sent to pipeline as expected - assert provisioning_pipeline.register.call_args == mocker.call( - payload=mocker.ANY, callback=cb_mock_register - 
) - # Callback completions were waited upon as expected - assert cb_mock_register.completion.call_count == 1 - assert cb_mock_shutdown.completion.call_count == 0 - - @pytest.mark.it("Returns the registration result that the pipeline returned") - async def test_verifies_registration_result_returned( - self, mocker, provisioning_pipeline, registration_result - ): - result = registration_result - - def register_complete_success_callback(payload, callback): - callback(result=result) - - mocker.patch.object( - provisioning_pipeline, "register", side_effect=register_complete_success_callback - ) - - client = ProvisioningDeviceClient(provisioning_pipeline) - result_returned = await client.register() - assert result_returned == result - - @pytest.mark.it( - "Raises a client error if the `register` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ConnectionDroppedError, - client_exceptions.ConnectionDroppedError, - id="ConnectionDroppedError->ConnectionDroppedError", - ), - pytest.param( - pipeline_exceptions.ConnectionFailedError, - client_exceptions.ConnectionFailedError, - id="ConnectionFailedError->ConnectionFailedError", - ), - pytest.param( - pipeline_exceptions.UnauthorizedError, - client_exceptions.CredentialError, - id="UnauthorizedError->CredentialError", - ), - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationTimeout, - client_exceptions.OperationTimeout, - id="OperationTimeout->OperationTimeout", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - async def test_raises_error_on_register_pipeline_op_error( - self, mocker, client_error, pipeline_error, provisioning_pipeline - ): - error = pipeline_error() - - def register_complete_failure_callback(payload, callback): - 
callback(result=None, error=error) - - mocker.patch.object( - provisioning_pipeline, "register", side_effect=register_complete_failure_callback - ) - - client = ProvisioningDeviceClient(provisioning_pipeline) - - with pytest.raises(client_error) as e_info: - await client.register() - - assert e_info.value.__cause__ is error - assert provisioning_pipeline.register.call_count == 1 - - @pytest.mark.it( - "Raises a client error if the `shutdown` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - # The only expected errors are unexpected ones - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError") - ], - ) - async def test_raises_error_on_shutdown_pipeline_op_error( - self, mocker, pipeline_error, client_error, provisioning_pipeline, registration_result - ): - # success result is required to trigger shutdown - registration_result._status = "assigned" - - error = pipeline_error() - - def register_complete_success_callback(payload, callback): - callback(result=registration_result) - - def shutdown_failure_callback(callback): - callback(result=None, error=error) - - mocker.patch.object( - provisioning_pipeline, "register", side_effect=register_complete_success_callback - ) - mocker.patch.object( - provisioning_pipeline, "shutdown", side_effect=shutdown_failure_callback - ) - - client = ProvisioningDeviceClient(provisioning_pipeline) - with pytest.raises(client_error) as e_info: - await client.register() - - assert e_info.value.__cause__ is error - assert provisioning_pipeline.register.call_count == 1 - - -@pytest.mark.describe("ProvisioningDeviceClient (Async) - .set_provisioning_payload()") -class TestClientProvisioningPayload(object): - @pytest.mark.it("Sets the payload on the provisioning payload attribute") - @pytest.mark.parametrize( - "payload_input", - [ - pytest.param("Hello World", id="String input"), - pytest.param(222, id="Integer input"), - 
pytest.param(object(), id="Object input"), - pytest.param(None, id="None input"), - pytest.param([1, "str"], id="List input"), - pytest.param({"a": 2}, id="Dictionary input"), - ], - ) - async def test_set_payload(self, mocker, payload_input): - provisioning_pipeline = mocker.MagicMock() - - client = ProvisioningDeviceClient(provisioning_pipeline) - client.provisioning_payload = payload_input - assert client._provisioning_payload == payload_input - - @pytest.mark.it("Gets the payload from provisioning payload property") - @pytest.mark.parametrize( - "payload_input", - [ - pytest.param("Hello World", id="String input"), - pytest.param(222, id="Integer input"), - pytest.param(object(), id="Object input"), - pytest.param(None, id="None input"), - pytest.param([1, "str"], id="List input"), - pytest.param({"a": 2}, id="Dictionary input"), - ], - ) - async def test_get_payload(self, mocker, payload_input): - provisioning_pipeline = mocker.MagicMock() - - client = ProvisioningDeviceClient(provisioning_pipeline) - client.provisioning_payload = payload_input - assert client.provisioning_payload == payload_input diff --git a/tests/unit/provisioning/conftest.py b/tests/unit/provisioning/conftest.py deleted file mode 100644 index c67b32f4e..000000000 --- a/tests/unit/provisioning/conftest.py +++ /dev/null @@ -1,20 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -from .shared_client_fixtures import ( # noqa: F401 - mock_pipeline_init, - provisioning_pipeline, - registration_result, - x509, -) - - -fake_status = "FakeStatus" -fake_sub_status = "FakeSubStatus" -fake_operation_id = "fake_operation_id" -fake_request_id = "request_1234" -fake_device_id = "MyDevice" -fake_assigned_hub = "MyIoTHub" diff --git a/tests/unit/provisioning/models/test_registration_result.py b/tests/unit/provisioning/models/test_registration_result.py deleted file mode 100644 index aba1bb4be..000000000 --- a/tests/unit/provisioning/models/test_registration_result.py +++ /dev/null @@ -1,133 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest -import logging -import datetime -from azure.iot.device.provisioning.models.registration_result import ( - RegistrationResult, - RegistrationState, -) -import json - -logging.basicConfig(level=logging.DEBUG) - -fake_request_id = "Request1234" -fake_operation_id = "Operation4567" -fake_status = "FakeStatus" -fake_device_id = "MyDevice" -fake_assigned_hub = "MyIoTHub" -fake_sub_status = "FakeSubStatus" -fake_created_dttm = datetime.datetime(2020, 5, 17) -fake_last_update_dttm = datetime.datetime(2020, 10, 17) -fake_etag = "SomeEtag" -fake_payload = "this is a fake payload" - - -@pytest.mark.describe("RegistrationResult") -class TestRegistrationResult(object): - @pytest.mark.it("Instantiates correctly") - def test_registration_result_instantiated_correctly(self): - fake_registration_state = create_registration_state() - registration_result = create_registration_result(fake_registration_state) - - assert registration_result.operation_id == fake_operation_id 
- assert registration_result.status == fake_status - assert registration_result.registration_state == fake_registration_state - - assert registration_result.registration_state.device_id == fake_device_id - assert registration_result.registration_state.assigned_hub == fake_assigned_hub - assert registration_result.registration_state.sub_status == fake_sub_status - assert registration_result.registration_state.created_date_time == fake_created_dttm - assert registration_result.registration_state.last_update_date_time == fake_last_update_dttm - assert registration_result.registration_state.etag == fake_etag - - @pytest.mark.it("Has a to string representation composed of registration state and status") - def test_registration_result_to_string(self): - fake_registration_state = create_registration_state() - registration_result = create_registration_result(fake_registration_state) - - string_repr = "\n".join([str(fake_registration_state), fake_status]) - assert str(registration_result) == string_repr - - @pytest.mark.parametrize( - "input_setter_code", - [ - pytest.param('registration_result.operation_id = "NewOperationId"', id="Operation Id"), - pytest.param('registration_result.status = "NewStatus"', id="Status"), - pytest.param( - 'registration_result.registration_state = "NewRegistrationState"', - id="Registration State", - ), - ], - ) - @pytest.mark.it("Has attributes that do not have setter") - def test_some_properties_of_result_are_not_settable(self, input_setter_code): - registration_result = create_registration_result() # noqa: F841 - with pytest.raises(AttributeError, match="can't set attribute"): - exec(input_setter_code) - - @pytest.mark.parametrize( - "input_setter_code", - [ - pytest.param('registration_state.device_id = "NewDeviceId"', id="Device Id"), - pytest.param('registration_state.assigned_hub = "NewHub"', id="Assigned Hub"), - pytest.param('registration_state.sub_status = "NewSubStatus"', id="Substatus"), - pytest.param('registration_state.etag = 
"NewEtag"', id="Etag"), - pytest.param( - "registration_state.created_date_time = datetime.datetime(3000, 10, 17)", - id="Create Date Time", - ), - pytest.param( - "registration_state.last_update_date_time = datetime.datetime(3000, 10, 17)", - id="Last Update Date Time", - ), - ], - ) - @pytest.mark.it("Has `RegistrationState` with properties that do not have setter") - def test_some_properties_of_state_are_not_settable(self, input_setter_code): - registration_state = create_registration_state() # noqa: F841 - - with pytest.raises(AttributeError, match="can't set attribute"): - exec(input_setter_code) - - @pytest.mark.it( - "Has a to string representation composed of device id, assigned hub and sub status" - ) - def test_registration_state_to_string_without_payload(self): - registration_state = create_registration_state() - # Serializes the __dict__ of every object instead of the object itself. - # Helpful for all sorts of complex objects. - json_payload = json.dumps(None, default=lambda o: o.__dict__, sort_keys=True) - - string_repr = "\n".join([fake_device_id, fake_assigned_hub, fake_sub_status, json_payload]) - assert str(registration_state) == string_repr - - @pytest.mark.it( - "Has a to string representation composed of device id, assigned hub, sub status and response payload" - ) - def test_registration_state_to_string_with_payload(self): - registration_state = create_registration_state(fake_payload) - json_payload = json.dumps(fake_payload, default=lambda o: o.__dict__, sort_keys=True) - - string_repr = "\n".join([fake_device_id, fake_assigned_hub, fake_sub_status, json_payload]) - assert str(registration_state) == string_repr - - -def create_registration_state(payload=None): - return RegistrationState( - fake_device_id, - fake_assigned_hub, - fake_sub_status, - fake_created_dttm, - fake_last_update_dttm, - fake_etag, - payload, - ) - - -def create_registration_result(registration_state=None): - return RegistrationResult(fake_operation_id, fake_status, 
registration_state) diff --git a/tests/unit/provisioning/pipeline/__init__.py b/tests/unit/provisioning/pipeline/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/provisioning/pipeline/conftest.py b/tests/unit/provisioning/pipeline/conftest.py deleted file mode 100644 index 5abca4b14..000000000 --- a/tests/unit/provisioning/pipeline/conftest.py +++ /dev/null @@ -1,14 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -from tests.unit.common.pipeline.fixtures import ( # noqa: F401 - fake_pipeline_thread, - fake_non_pipeline_thread, - arbitrary_op, - arbitrary_event, - nucleus, - pipeline_connected_mock, -) diff --git a/tests/unit/provisioning/pipeline/helpers.py b/tests/unit/provisioning/pipeline/helpers.py deleted file mode 100644 index bc61db593..000000000 --- a/tests/unit/provisioning/pipeline/helpers.py +++ /dev/null @@ -1,16 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -from azure.iot.device.provisioning.pipeline import pipeline_ops_provisioning - -all_provisioning_ops = [ - pipeline_ops_provisioning.RegisterOperation, - pipeline_ops_provisioning.PollStatusOperation, -] - -fake_key_values = {} -fake_key_values["request_id"] = ["request_1234", " "] -fake_key_values["retry-after"] = ["300", " "] -fake_key_values["name"] = ["hermione", " "] diff --git a/tests/unit/provisioning/pipeline/test_config.py b/tests/unit/provisioning/pipeline/test_config.py deleted file mode 100644 index ab0740a6a..000000000 --- a/tests/unit/provisioning/pipeline/test_config.py +++ /dev/null @@ -1,48 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import pytest -from tests.unit.common.pipeline.config_test import PipelineConfigInstantiationTestBase -from azure.iot.device.provisioning.pipeline.config import ProvisioningPipelineConfig - -hostname = "hostname.some-domain.net" -registration_id = "registration_id" -id_scope = "id_scope" - - -@pytest.mark.describe("ProvisioningPipelineConfig - Instantiation") -class TestProvisioningPipelineConfigInstantiation(PipelineConfigInstantiationTestBase): - @pytest.fixture - def config_cls(self): - # This fixture is needed for the parent class - return ProvisioningPipelineConfig - - @pytest.fixture - def required_kwargs(self): - # This fixture is needed for the parent class - return {"hostname": hostname, "registration_id": registration_id, "id_scope": id_scope} - - # The parent class defines the auth mechanism fixtures (sastoken, x509). - # For the sake of ease of testing, we will assume sastoken is being used unless - # there is a strict need to do something else. 
- # It does not matter which is used for the purposes of these tests. - - @pytest.mark.it( - "Instantiates with the 'registration_id' attribute set to the provided 'registration_id' parameter" - ) - def test_registration_id_set(self, sastoken): - config = ProvisioningPipelineConfig( - hostname=hostname, registration_id=registration_id, id_scope=id_scope, sastoken=sastoken - ) - assert config.registration_id == registration_id - - @pytest.mark.it( - "Instantiates with the 'id_scope' attribute set to the provided 'id_scope' parameter" - ) - def test_id_scope_set(self, sastoken): - config = ProvisioningPipelineConfig( - hostname=hostname, registration_id=registration_id, id_scope=id_scope, sastoken=sastoken - ) - assert config.id_scope == id_scope diff --git a/tests/unit/provisioning/pipeline/test_mqtt_pipeline.py b/tests/unit/provisioning/pipeline/test_mqtt_pipeline.py deleted file mode 100644 index 09085ff1e..000000000 --- a/tests/unit/provisioning/pipeline/test_mqtt_pipeline.py +++ /dev/null @@ -1,480 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import pytest -import logging -from azure.iot.device.common.models import X509 -from azure.iot.device.provisioning.pipeline.mqtt_pipeline import MQTTPipeline -from azure.iot.device.provisioning.pipeline import constant as dps_constants -from azure.iot.device.provisioning.pipeline import ( - pipeline_stages_provisioning, - pipeline_stages_provisioning_mqtt, - pipeline_ops_provisioning, -) -from azure.iot.device.common.pipeline import ( - pipeline_stages_base, - pipeline_stages_mqtt, - pipeline_ops_base, - pipeline_exceptions, - pipeline_nucleus, -) - -logging.basicConfig(level=logging.DEBUG) -pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") - - -feature = dps_constants.REGISTER - - -def mock_x509(): - return X509(cert_file="fake_cert_file", key_file="fake_key_file", pass_phrase="some_password") - - -@pytest.fixture -def pipeline_configuration(mocker): - mock_config = mocker.MagicMock() - mock_config.sastoken.ttl = 1232 # set for compat - mock_config.sastoken.expiry_time = 1232131 # set for compat - mock_config.registration_id = "MyRegistration" - return mock_config - - -@pytest.fixture -def pipeline(mocker, pipeline_configuration): - pipeline = MQTTPipeline(pipeline_configuration) - mocker.patch.object(pipeline._pipeline, "run_op") - return pipeline - - -# automatically mock the transport for all tests in this file. 
-@pytest.fixture(autouse=True) -def mock_mqtt_transport(mocker): - return mocker.patch( - "azure.iot.device.common.pipeline.pipeline_stages_mqtt.MQTTTransport", autospec=True - ) - - -@pytest.mark.describe("MQTTPipeline - Instantiation") -class TestMQTTPipelineInstantiation(object): - @pytest.mark.it("Begins tracking the enabled/disabled status of responses") - def test_features(self, pipeline_configuration): - pipeline = MQTTPipeline(pipeline_configuration) - pipeline.responses_enabled[feature] - # No assertion required - if this doesn't raise a KeyError, it is a success - - @pytest.mark.it("Marks responses as disabled") - def test_features_disabled(self, pipeline_configuration): - pipeline = MQTTPipeline(pipeline_configuration) - assert not pipeline.responses_enabled[feature] - - @pytest.mark.it("Sets all handlers to an initial value of None") - def test_handlers_set_to_none(self, pipeline_configuration): - pipeline = MQTTPipeline(pipeline_configuration) - assert pipeline.on_connected is None - assert pipeline.on_disconnected is None - assert pipeline.on_background_exception is None - assert pipeline.on_message_received is None - - @pytest.mark.it("Configures the pipeline to trigger handlers in response to external events") - def test_handlers_configured(self, pipeline_configuration): - pipeline = MQTTPipeline(pipeline_configuration) - assert pipeline._pipeline.on_pipeline_event_handler is not None - assert pipeline._pipeline.on_connected_handler is not None - assert pipeline._pipeline.on_disconnected_handler is not None - - @pytest.mark.it("Configures the pipeline with a PipelineNucleus") - def test_pipeline_nucleus(self, pipeline_configuration): - pipeline = MQTTPipeline(pipeline_configuration) - - assert isinstance(pipeline._nucleus, pipeline_nucleus.PipelineNucleus) - assert pipeline._nucleus.pipeline_configuration is pipeline_configuration - - @pytest.mark.it("Configures the pipeline with a series of PipelineStages") - def test_pipeline_stages(self, 
pipeline_configuration): - pipeline = MQTTPipeline(pipeline_configuration) - curr_stage = pipeline._pipeline - - expected_stage_order = [ - pipeline_stages_base.PipelineRootStage, - pipeline_stages_base.SasTokenStage, - pipeline_stages_provisioning.RegistrationStage, - pipeline_stages_provisioning.PollingStatusStage, - pipeline_stages_base.CoordinateRequestAndResponseStage, - pipeline_stages_provisioning_mqtt.ProvisioningMQTTTranslationStage, - pipeline_stages_base.AutoConnectStage, - pipeline_stages_base.ConnectionStateStage, - pipeline_stages_base.RetryStage, - pipeline_stages_base.OpTimeoutStage, - pipeline_stages_mqtt.MQTTTransportStage, - ] - - # Assert that all PipelineStages are there, and they are in the right order - for i in range(len(expected_stage_order)): - expected_stage = expected_stage_order[i] - assert isinstance(curr_stage, expected_stage) - assert curr_stage.nucleus is pipeline._nucleus - curr_stage = curr_stage.next - - # Assert there are no more additional stages - assert curr_stage is None - - @pytest.mark.it("Runs an InitializePipelineOperation on the pipeline") - def test_init_pipeline(self, mocker, pipeline_configuration): - mocker.spy(pipeline_stages_base.PipelineRootStage, "run_op") - - pipeline = MQTTPipeline(pipeline_configuration) - - op = pipeline._pipeline.run_op.call_args[0][1] - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance(op, pipeline_ops_base.InitializePipelineOperation) - - @pytest.mark.it( - "Sets a flag to indicate the pipeline is 'running' upon successful completion of the InitializePipelineOperation" - ) - def test_running(self, mocker, pipeline_configuration): - # Because this is an init test, there isn't really a way to check that it only occurs after - # the op. The reason is because this is the object's init, the object doesn't actually - # exist until the entire method has completed, so there's no reference you can check prior - # to method completion. 
- pipeline = MQTTPipeline(pipeline_configuration) - assert pipeline._running - - @pytest.mark.it( - "Raises exceptions that occurred in execution upon unsuccessful completion of the InitializePipelineOperation" - ) - def test_init_pipeline_failure(self, mocker, arbitrary_exception, pipeline_configuration): - old_run_op = pipeline_stages_base.PipelineRootStage._run_op - - def fail_set_security_client(self, op): - if isinstance(op, pipeline_ops_base.InitializePipelineOperation): - op.complete(error=arbitrary_exception) - else: - old_run_op(self, op) - - mocker.patch.object( - pipeline_stages_base.PipelineRootStage, - "_run_op", - side_effect=fail_set_security_client, - autospec=True, - ) - - with pytest.raises(arbitrary_exception.__class__) as e_info: - MQTTPipeline(pipeline_configuration) - assert e_info.value is arbitrary_exception - - -@pytest.mark.describe("MQTTPipeline - .shutdown()") -class TestMQTTPipelineShutdown(object): - @pytest.mark.it( - "Raises a PipelineNotRunning exception if the pipeline is not running (i.e. 
already shut down)" - ) - def test_not_running(self, mocker, pipeline): - pipeline._running = False - - with pytest.raises(pipeline_exceptions.PipelineNotRunning): - pipeline.shutdown(callback=mocker.MagicMock()) - - @pytest.mark.it("Runs a ShutdownPipelineOperation on the pipeline") - def test_runs_op(self, pipeline, mocker): - cb = mocker.MagicMock() - pipeline.shutdown(callback=cb) - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance( - pipeline._pipeline.run_op.call_args[0][0], pipeline_ops_base.ShutdownPipelineOperation - ) - - @pytest.mark.it( - "Triggers the callback upon successful completion of the ShutdownPipelineOperation" - ) - def test_op_success_with_callback(self, mocker, pipeline): - cb = mocker.MagicMock() - - # Begin operation - pipeline.shutdown(callback=cb) - assert cb.call_count == 0 - - # Trigger op completion - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=None) - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the ShutdownPipelineOperation" - ) - def test_op_fail(self, mocker, pipeline, arbitrary_exception): - cb = mocker.MagicMock() - - pipeline.shutdown(callback=cb) - op = pipeline._pipeline.run_op.call_args[0][0] - - op.complete(error=arbitrary_exception) - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=arbitrary_exception) - - @pytest.mark.it( - "Sets a flag to indicate the pipeline is no longer running only upon successful completion of the ShutdownPipelineOperation" - ) - def test_set_not_running(self, mocker, pipeline, arbitrary_exception): - # Pipeline is running - assert pipeline._running - - # Begin operation (we will fail this one) - cb = mocker.MagicMock() - pipeline.shutdown(callback=cb) - assert cb.call_count == 0 - - # Pipeline is still running - assert pipeline._running - - # Trigger op completion (failure) - op = 
pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=arbitrary_exception) - - # Pipeline is still running - assert pipeline._running - - # Try operation again (we will make this one succeed) - cb.reset_mock() - pipeline.shutdown(callback=cb) - assert cb.call_count == 0 - - # Trigger op completion (successful) - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=None) - - # Pipeline is no longer running - assert not pipeline._running - - -@pytest.mark.describe("MQTTPipeline - .connect()") -class TestMQTTPipelineConnect(object): - @pytest.mark.it( - "Raises a PipelineNotRunning exception if the pipeline is not running (i.e. already shut down)" - ) - def test_not_running(self, mocker, pipeline): - pipeline._running = False - - with pytest.raises(pipeline_exceptions.PipelineNotRunning): - pipeline.connect(callback=mocker.MagicMock()) - - @pytest.mark.it("Runs a ConnectOperation on the pipeline") - def test_runs_op(self, pipeline, mocker): - cb = mocker.MagicMock() - pipeline.connect(callback=cb) - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance( - pipeline._pipeline.run_op.call_args[0][0], pipeline_ops_base.ConnectOperation - ) - - @pytest.mark.it("Triggers the callback upon successful completion of the ConnectOperation") - def test_op_success_with_callback(self, mocker, pipeline): - cb = mocker.MagicMock() - - # Begin operation - pipeline.connect(callback=cb) - assert cb.call_count == 0 - - # Trigger op completion - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=None) - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the ConnectOperation" - ) - def test_op_fail(self, mocker, pipeline, arbitrary_exception): - cb = mocker.MagicMock() - - pipeline.connect(callback=cb) - op = pipeline._pipeline.run_op.call_args[0][0] - - op.complete(error=arbitrary_exception) - assert cb.call_count == 1 - 
assert cb.call_args == mocker.call(error=arbitrary_exception) - - -@pytest.mark.describe("MQTTPipeline - .disconnect()") -class TestMQTTPipelineDisconnect(object): - @pytest.mark.it( - "Raises a PipelineNotRunning exception if the pipeline is not running (i.e. already shut down)" - ) - def test_not_running(self, mocker, pipeline): - pipeline._running = False - - with pytest.raises(pipeline_exceptions.PipelineNotRunning): - pipeline.disconnect(callback=mocker.MagicMock()) - - @pytest.mark.it("Runs a DisconnectOperation on the pipeline") - def test_runs_op(self, pipeline, mocker): - pipeline.disconnect(callback=mocker.MagicMock()) - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance( - pipeline._pipeline.run_op.call_args[0][0], pipeline_ops_base.DisconnectOperation - ) - - @pytest.mark.it("Triggers the callback upon successful completion of the DisconnectOperation") - def test_op_success_with_callback(self, mocker, pipeline): - cb = mocker.MagicMock() - - # Begin operation - pipeline.disconnect(callback=cb) - assert cb.call_count == 0 - - # Trigger op completion callback - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=None) - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the DisconnectOperation" - ) - def test_op_fail(self, mocker, pipeline, arbitrary_exception): - cb = mocker.MagicMock() - pipeline.disconnect(callback=cb) - - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=arbitrary_exception) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=arbitrary_exception) - - -@pytest.mark.describe("MQTTPipeline - .register()") -class TestSendRegister(object): - @pytest.mark.it( - "Raises a PipelineNotRunning exception if the pipeline is not running (i.e. 
already shut down)" - ) - def test_not_running(self, mocker, pipeline): - pipeline._running = False - - with pytest.raises(pipeline_exceptions.PipelineNotRunning): - pipeline.register(callback=mocker.MagicMock()) - - @pytest.mark.it("Runs a RegisterOperation on the pipeline") - def test_runs_op(self, pipeline, mocker): - cb = mocker.MagicMock() - pipeline.register(callback=cb) - assert pipeline._pipeline.run_op.call_count == 1 - op = pipeline._pipeline.run_op.call_args[0][0] - assert isinstance(op, pipeline_ops_provisioning.RegisterOperation) - assert op.registration_id == pipeline._nucleus.pipeline_configuration.registration_id - - @pytest.mark.it("passes the payload parameter as request_payload on the RegistrationRequest") - def test_sets_request_payload(self, pipeline, mocker): - cb = mocker.MagicMock() - fake_request_payload = "fake_request_payload" - pipeline.register(payload=fake_request_payload, callback=cb) - op = pipeline._pipeline.run_op.call_args[0][0] - assert op.request_payload is fake_request_payload - - @pytest.mark.it( - "sets request_payload on the RegistrationRequest to None if no payload is provided" - ) - def test_sets_empty_payload(self, pipeline, mocker): - cb = mocker.MagicMock() - pipeline.register(callback=cb) - op = pipeline._pipeline.run_op.call_args[0][0] - assert op.request_payload is None - - @pytest.mark.it( - "Triggers the callback upon successful completion of the RegisterOperation, passing the registration result in the result parameter" - ) - def test_op_success_with_callback(self, mocker, pipeline): - cb = mocker.MagicMock() - - # Begin operation - pipeline.register(callback=cb) - assert cb.call_count == 0 - - # Trigger op completion - op = pipeline._pipeline.run_op.call_args[0][0] - fake_registration_result = "fake_result" - op.registration_result = fake_registration_result - op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(result=fake_registration_result) - - @pytest.mark.it( - "Calls 
the callback with the error upon unsuccessful completion of the RegisterOperation" - ) - def test_op_fail(self, mocker, pipeline, arbitrary_exception): - cb = mocker.MagicMock() - - pipeline.register(callback=cb) - - op = pipeline._pipeline.run_op.call_args[0][0] - fake_registration_result = "fake_result" - op.registration_result = fake_registration_result - op.complete(error=arbitrary_exception) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=arbitrary_exception, result=None) - - -@pytest.mark.describe("MQTTPipeline - .enable_responses()") -class TestEnable(object): - @pytest.mark.it( - "Raises a PipelineNotRunning exception if the pipeline is not running (i.e. already shut down)" - ) - def test_not_running(self, mocker, pipeline): - pipeline._running = False - - with pytest.raises(pipeline_exceptions.PipelineNotRunning): - pipeline.enable_responses(callback=mocker.MagicMock()) - - @pytest.mark.it("Marks the feature as enabled") - def test_mark_feature_enabled(self, pipeline, mocker): - assert not pipeline.responses_enabled[feature] - pipeline.enable_responses(callback=mocker.MagicMock()) - assert pipeline.responses_enabled[feature] - - @pytest.mark.it( - "Runs a EnableFeatureOperation on the pipeline, passing in the name of the feature" - ) - def test_runs_op(self, pipeline, mocker): - pipeline.enable_responses(callback=mocker.MagicMock()) - op = pipeline._pipeline.run_op.call_args[0][0] - - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance(op, pipeline_ops_base.EnableFeatureOperation) - assert op.feature_name == dps_constants.REGISTER - - @pytest.mark.it( - "Triggers the callback upon successful completion of the EnableFeatureOperation" - ) - def test_op_success_with_callback(self, mocker, pipeline): - cb = mocker.MagicMock() - - # Begin operation - pipeline.enable_responses(callback=cb) - assert cb.call_count == 0 - - # Trigger op completion callback - op = pipeline._pipeline.run_op.call_args[0][0] - 
op.complete(error=None) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=None) - - @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the EnableFeatureOperation" - ) - def test_op_fail(self, mocker, pipeline, arbitrary_exception): - cb = mocker.MagicMock() - pipeline.enable_responses(callback=cb) - - op = pipeline._pipeline.run_op.call_args[0][0] - op.complete(error=arbitrary_exception) - - assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=arbitrary_exception) diff --git a/tests/unit/provisioning/pipeline/test_pipeline_ops_provisioning.py b/tests/unit/provisioning/pipeline/test_pipeline_ops_provisioning.py deleted file mode 100644 index 875484e24..000000000 --- a/tests/unit/provisioning/pipeline/test_pipeline_ops_provisioning.py +++ /dev/null @@ -1,122 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import pytest -import sys -import logging -from azure.iot.device.provisioning.pipeline import pipeline_ops_provisioning -from tests.unit.common.pipeline import pipeline_ops_test - -logging.basicConfig(level=logging.DEBUG) -this_module = sys.modules[__name__] -pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") - - -class RegisterOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_provisioning.RegisterOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = { - "request_payload": "some_request_payload", - "registration_id": "some_registration_id", - "callback": mocker.MagicMock(), - } - return kwargs - - -class RegisterOperationInstantiationTests(RegisterOperationTestConfig): - @pytest.mark.it( - "Initializes 'request_payload' attribute with the provided 'request_payload' parameter" - ) - def test_request_payload(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.request_payload == init_kwargs["request_payload"] - - @pytest.mark.it( - "Initializes 'registration_id' attribute with the provided 'registration_id' parameter" - ) - def test_registration_id(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.registration_id == init_kwargs["registration_id"] - - @pytest.mark.it("Initializes 'retry_after_timer' attribute to None") - def test_retry_after_timer(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.retry_after_timer is None - - @pytest.mark.it("Initializes 'polling_timer' attribute to None") - def test_polling_timer(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.polling_timer is None - - @pytest.mark.it("Initializes 'provisioning_timeout_timer' attribute to None") - def test_provisioning_timeout_timer(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.provisioning_timeout_timer is None - - 
-pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_provisioning.RegisterOperation, - op_test_config_class=RegisterOperationTestConfig, - extended_op_instantiation_test_class=RegisterOperationInstantiationTests, -) - - -class PollStatusOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_provisioning.PollStatusOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = { - "operation_id": "some_operation_id", - "request_payload": "some_request_payload", - "callback": mocker.MagicMock(), - } - return kwargs - - -class PollStatusOperationInstantiationTests(PollStatusOperationTestConfig): - @pytest.mark.it( - "Initializes 'operation_id' attribute with the provided 'operation_id' parameter" - ) - def test_operation_id(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.operation_id == init_kwargs["operation_id"] - - @pytest.mark.it( - "Initializes 'request_payload' attribute with the provided 'request_payload' parameter" - ) - def test_request_payload(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.request_payload == init_kwargs["request_payload"] - - @pytest.mark.it("Initializes 'retry_after_timer' attribute to None") - def test_retry_after_timer(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.retry_after_timer is None - - @pytest.mark.it("Initializes 'polling_timer' attribute to None") - def test_polling_timer(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.polling_timer is None - - @pytest.mark.it("Initializes 'provisioning_timeout_timer' attribute to None") - def test_provisioning_timeout_timer(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.provisioning_timeout_timer is None - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_provisioning.PollStatusOperation, - 
op_test_config_class=PollStatusOperationTestConfig, - extended_op_instantiation_test_class=PollStatusOperationInstantiationTests, -) diff --git a/tests/unit/provisioning/pipeline/test_pipeline_stages_provisioning.py b/tests/unit/provisioning/pipeline/test_pipeline_stages_provisioning.py deleted file mode 100644 index ac7b187b1..000000000 --- a/tests/unit/provisioning/pipeline/test_pipeline_stages_provisioning.py +++ /dev/null @@ -1,1028 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import logging -import pytest -import sys -import json -import datetime -from azure.iot.device.provisioning.pipeline import ( - pipeline_stages_provisioning, - pipeline_ops_provisioning, -) -from azure.iot.device.common.pipeline import pipeline_ops_base -from tests.unit.common.pipeline import pipeline_stage_test -from azure.iot.device.exceptions import ServiceError - -from tests.unit.common.pipeline.helpers import StageRunOpTestBase -from azure.iot.device import exceptions -from azure.iot.device.provisioning.pipeline import constant - -logging.basicConfig(level=logging.DEBUG) -this_module = sys.modules[__name__] -pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") - - -fake_device_id = "fake_device" -fake_registration_id = "fake_registration_id" -fake_provisioning_host = "hostname.com" -fake_id_scope = "fake_id_scope" -fake_ca_cert = "fake_ca_cert" -fake_sas_token = "fake_sas_token" -fake_request_id = "Request1234" -fake_operation_id = "Operation4567" -fake_status = "fake_status" -fake_assigned_hub = "MyIoTHub" -fake_sub_status = "fake_sub_status" -fake_created_dttm = datetime.datetime(2020, 5, 17) -fake_last_update_dttm = datetime.datetime(2020, 10, 17) -fake_etag = "SomeEtag" -fake_payload = "this is a 
fake payload" -fake_symmetric_key = "Zm9vYmFy" -fake_x509_cert_file = "fake_cert_file" -fake_x509_cert_key_file = "fake_cert_key_file" -fake_pass_phrase = "fake_pass_phrase" - - -class FakeRegistrationResult(object): - def __init__(self, operation_id, status, state): - self.operationId = operation_id - self.status = status - self.registrationState = state - - def __str__(self): - return "\n".join([str(self.registrationState), self.status]) - - -class FakeRegistrationState(object): - def __init__(self, payload): - self.deviceId = fake_device_id - self.assignedHub = fake_assigned_hub - self.payload = payload - self.substatus = fake_sub_status - - def __str__(self): - return "\n".join( - [self.deviceId, self.assignedHub, self.substatus, self.get_payload_string()] - ) - - def get_payload_string(self): - return json.dumps(self.payload, default=lambda o: o.__dict__, sort_keys=True) - - -def create_registration_result(fake_payload, status): - state = FakeRegistrationState(payload=fake_payload) - return FakeRegistrationResult(fake_operation_id, status, state) - - -def get_registration_result_as_bytes(registration_result): - return json.dumps(registration_result, default=lambda o: o.__dict__).encode("utf-8") - - -################### -# COMMON FIXTURES # -################### - - -@pytest.fixture(params=[True, False], ids=["With error", "No error"]) -def op_error(request, arbitrary_exception): - if request.param: - return arbitrary_exception - else: - return None - - -############################### -# REGISTRATION STAGE # -############################### - - -class RegistrationStageConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_provisioning.RegistrationStage - - @pytest.fixture - def init_kwargs(self): - return {} - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, 
"report_background_exception") - return stage - - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_provisioning.RegistrationStage, - stage_test_config_class=RegistrationStageConfig, -) - - -@pytest.mark.describe("RegistrationStage - .run_op() -- called with RegisterOperation") -class TestRegistrationStageWithRegisterOperation(StageRunOpTestBase, RegistrationStageConfig): - @pytest.fixture(params=[" ", fake_payload], ids=["empty payload", "some payload"]) - def request_payload(self, request): - return request.param - - @pytest.fixture - def op(self, stage, mocker, request_payload): - op = pipeline_ops_provisioning.RegisterOperation( - request_payload, fake_registration_id, callback=mocker.MagicMock() - ) - yield op - - # Clean up any timers set on it - if op.provisioning_timeout_timer: - op.provisioning_timeout_timer.cancel() - if op.retry_after_timer: - op.retry_after_timer.cancel() - if op.provisioning_timeout_timer: - op.provisioning_timeout_timer.cancel() - - @pytest.fixture - def request_body(self, request_payload): - return '{{"payload": {json_payload}, "registrationId": "{reg_id}"}}'.format( - reg_id=fake_registration_id, json_payload=json.dumps(request_payload) - ) - - @pytest.mark.it( - "Sends a new RequestAndResponseOperation down the pipeline, configured to request a registration from provisioning service" - ) - def test_request_and_response_op(self, stage, op, request_body): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_base.RequestAndResponseOperation) - assert new_op.request_type == "register" - assert new_op.method == "PUT" - assert new_op.resource_location == "/" - assert new_op.request_body == request_body - - -@pytest.mark.describe("RegistrationStage - .run_op() -- Called with other arbitrary operation") -class TestRegistrationStageWithArbitraryOperation(StageRunOpTestBase, 
RegistrationStageConfig): - @pytest.fixture - def op(self, arbitrary_op): - return arbitrary_op - - @pytest.mark.it("Sends the operation down the pipeline") - def test_sends_op_down(self, mocker, stage, op): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - -@pytest.mark.describe( - "RegistrationStage - OCCURRENCE: RequestAndResponseOperation created from RegisterOperation is completed" -) -class TestRegistrationStageWithRegisterOperationCompleted(RegistrationStageConfig): - @pytest.fixture(params=[" ", fake_payload], ids=["empty payload", "some payload"]) - def request_payload(self, request): - return request.param - - @pytest.fixture - def send_registration_op(self, mocker, request_payload): - op = pipeline_ops_provisioning.RegisterOperation( - request_payload, fake_registration_id, callback=mocker.MagicMock() - ) - yield op - - # Clean up any timers set on it - if op.provisioning_timeout_timer: - op.provisioning_timeout_timer.cancel() - if op.retry_after_timer: - op.retry_after_timer.cancel() - if op.provisioning_timeout_timer: - op.provisioning_timeout_timer.cancel() - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, send_registration_op): - stage = cls_type(**init_kwargs) - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - # Run the registration operation - stage.run_op(send_registration_op) - return stage - - @pytest.fixture - def request_and_response_op(self, stage): - assert stage.send_op_down.call_count == 1 - op = stage.send_op_down.call_args[0][0] - assert isinstance(op, pipeline_ops_base.RequestAndResponseOperation) - # reset the stage mock for convenience - stage.send_op_down.reset_mock() - return op - - @pytest.fixture - def request_body(self, request_payload): - return '{{"payload": {json_payload}, "registrationId": "{reg_id}"}}'.format( - reg_id=fake_registration_id, 
json_payload=json.dumps(request_payload) - ) - - @pytest.mark.it( - "Completes the RegisterOperation unsuccessfully, with the error from the RequestAndResponseOperation, if the RequestAndResponseOperation is completed unsuccessfully" - ) - @pytest.mark.parametrize( - "status_code", - [ - pytest.param(None, id="Status Code: None"), - pytest.param(200, id="Status Code: 200"), - pytest.param(300, id="Status Code: 300"), - pytest.param(400, id="Status Code: 400"), - pytest.param(500, id="Status Code: 500"), - ], - ) - @pytest.mark.parametrize( - "has_response_body", [True, False], ids=["With Response Body", "No Response Body"] - ) - def test_request_and_response_op_completed_with_err( - self, - stage, - send_registration_op, - request_and_response_op, - status_code, - has_response_body, - arbitrary_exception, - ): - assert not send_registration_op.completed - assert not request_and_response_op.completed - - # NOTE: It shouldn't happen that an operation completed with error has a status code or a - # response body, but it IS possible. 
- request_and_response_op.status_code = status_code - if has_response_body: - request_and_response_op.response_body = b'{"key": "value"}' - request_and_response_op.complete(error=arbitrary_exception) - - assert request_and_response_op.completed - assert request_and_response_op.error is arbitrary_exception - assert send_registration_op.completed - assert send_registration_op.error is arbitrary_exception - assert send_registration_op.registration_result is None - - @pytest.mark.it( - "Completes the RegisterOperation unsuccessfully with a ServiceError if the RequestAndResponseOperation is completed with a status code >= 300 and less than 429" - ) - @pytest.mark.parametrize( - "has_response_body", [True, False], ids=["With Response Body", "No Response Body"] - ) - @pytest.mark.parametrize( - "status_code", - [ - pytest.param(300, id="Status Code: 300"), - pytest.param(400, id="Status Code: 400"), - pytest.param(428, id="Status Code: 428"), - ], - ) - def test_request_and_response_op_completed_success_with_bad_code( - self, stage, send_registration_op, request_and_response_op, status_code, has_response_body - ): - assert not send_registration_op.completed - assert not request_and_response_op.completed - - request_and_response_op.status_code = status_code - if has_response_body: - request_and_response_op.response_body = b'{"key": "value"}' - request_and_response_op.complete() - - assert request_and_response_op.completed - assert request_and_response_op.error is None - assert send_registration_op.completed - assert isinstance(send_registration_op.error, ServiceError) - # Twin is NOT returned - assert send_registration_op.registration_result is None - - @pytest.mark.it( - "Decodes, deserializes, and returns registration_result on the RegisterOperation op when RequestAndResponseOperation completes with no error if the status code < 300 and if status is 'assigned'" - ) - def test_request_and_response_op_completed_success_with_status_assigned( - self, stage, request_payload, 
send_registration_op, request_and_response_op - ): - registration_result = create_registration_result(request_payload, "assigned") - - assert not send_registration_op.completed - assert not request_and_response_op.completed - - request_and_response_op.status_code = 200 - request_and_response_op.retry_after = None - request_and_response_op.response_body = get_registration_result_as_bytes( - registration_result - ) - request_and_response_op.complete() - - assert request_and_response_op.completed - assert request_and_response_op.error is None - assert send_registration_op.completed - assert send_registration_op.error is None - # We need to assert string representations as these are inherently different objects - assert str(send_registration_op.registration_result) == str(registration_result) - - @pytest.mark.it( - "Decodes, deserializes, and returns registration_result along with an error on the RegisterOperation op when RequestAndResponseOperation completes with status code < 300 and status 'failed'" - ) - def test_request_and_response_op_completed_success_with_status_failed( - self, stage, request_payload, send_registration_op, request_and_response_op - ): - registration_result = create_registration_result(request_payload, "failed") - - assert not send_registration_op.completed - assert not request_and_response_op.completed - - request_and_response_op.status_code = 200 - request_and_response_op.retry_after = None - request_and_response_op.response_body = get_registration_result_as_bytes( - registration_result - ) - request_and_response_op.complete() - - assert request_and_response_op.completed - assert request_and_response_op.error is None - assert send_registration_op.completed - assert isinstance(send_registration_op.error, ServiceError) - # We need to assert string representations as these are inherently different objects - assert str(send_registration_op.registration_result) == str(registration_result) - assert "failed registration status" in 
str(send_registration_op.error) - - @pytest.mark.it( - "Returns error on the RegisterOperation op when RequestAndResponseOperation completes with status code < 300 and some unknown status" - ) - def test_request_and_response_op_completed_success_with_unknown_status( - self, stage, request_payload, send_registration_op, request_and_response_op - ): - registration_result = create_registration_result(request_payload, "some_status") - - assert not send_registration_op.completed - assert not request_and_response_op.completed - - request_and_response_op.status_code = 200 - request_and_response_op.retry_after = None - request_and_response_op.response_body = get_registration_result_as_bytes( - registration_result - ) - request_and_response_op.complete() - - assert request_and_response_op.completed - assert request_and_response_op.error is None - assert send_registration_op.completed - assert isinstance(send_registration_op.error, ServiceError) - assert "invalid registration status" in str(send_registration_op.error) - - @pytest.mark.it( - "Decodes, deserializes the response from RequestAndResponseOperation and creates another op if the status code < 300 and if status is 'assigning'" - ) - def test_spawns_another_op_request_and_response_op_completed_success_with_status_assigning( - self, mocker, stage, request_payload, send_registration_op, request_and_response_op - ): - mock_timer = mocker.patch( - "azure.iot.device.provisioning.pipeline.pipeline_stages_provisioning.Timer" - ) - - mocker.spy(send_registration_op, "spawn_worker_op") - registration_result = create_registration_result(request_payload, "assigning") - - assert not send_registration_op.completed - assert not request_and_response_op.completed - - request_and_response_op.status_code = 200 - request_and_response_op.retry_after = None - request_and_response_op.response_body = get_registration_result_as_bytes( - registration_result - ) - request_and_response_op.complete() - - assert 
send_registration_op.retry_after_timer is None - assert send_registration_op.polling_timer is not None - timer_callback = mock_timer.call_args[0][1] - timer_callback() - - assert request_and_response_op.completed - assert request_and_response_op.error is None - assert not send_registration_op.completed - assert send_registration_op.error is None - assert ( - send_registration_op.spawn_worker_op.call_args[1]["operation_id"] == fake_operation_id - ) - - -class RetryStageConfig(object): - @pytest.fixture - def init_kwargs(self): - return {} - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - mocker.spy(stage, "run_op") - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - -@pytest.mark.describe("RegistrationStage - .run_op() -- retried again with RegisterOperation") -class TestRegistrationStageWithRetryOfRegisterOperation(RetryStageConfig): - @pytest.fixture(params=[" ", fake_payload], ids=["empty payload", "some payload"]) - def request_payload(self, request): - return request.param - - @pytest.fixture - def cls_type(self): - return pipeline_stages_provisioning.RegistrationStage - - @pytest.fixture - def op(self, stage, mocker, request_payload): - op = pipeline_ops_provisioning.RegisterOperation( - request_payload, fake_registration_id, callback=mocker.MagicMock() - ) - yield op - - # Clean up any timers set on it - if op.provisioning_timeout_timer: - op.provisioning_timeout_timer.cancel() - if op.retry_after_timer: - op.retry_after_timer.cancel() - if op.provisioning_timeout_timer: - op.provisioning_timeout_timer.cancel() - - @pytest.fixture - def request_body(self, request_payload): - return '{{"payload": {json_payload}, "registrationId": "{reg_id}"}}'.format( - reg_id=fake_registration_id, json_payload=json.dumps(request_payload) - ) - - @pytest.mark.it( - "Decodes, deserializes the response from 
RequestAndResponseOperation and retries the op if the status code > 429" - ) - def test_stage_retries_op_if_next_stage_responds_with_status_code_greater_than_429( - self, mocker, stage, op, request_body, request_payload - ): - mock_timer = mocker.patch( - "azure.iot.device.provisioning.pipeline.pipeline_stages_provisioning.Timer" - ) - - stage.run_op(op) - assert stage.send_op_down.call_count == 1 - next_op = stage.send_op_down.call_args[0][0] - assert isinstance(next_op, pipeline_ops_base.RequestAndResponseOperation) - - next_op.status_code = 430 - next_op.retry_after = "1" - registration_result = create_registration_result(request_payload, "some_status") - next_op.response_body = get_registration_result_as_bytes(registration_result) - next_op.complete() - - assert op.retry_after_timer is not None - assert op.polling_timer is None - timer_callback = mock_timer.call_args[0][1] - timer_callback() - - assert stage.run_op.call_count == 2 - assert stage.send_op_down.call_count == 2 - - next_op_2 = stage.send_op_down.call_args[0][0] - assert isinstance(next_op_2, pipeline_ops_base.RequestAndResponseOperation) - assert next_op_2.request_type == "register" - assert next_op_2.method == "PUT" - assert next_op_2.resource_location == "/" - assert next_op_2.request_body == request_body - - -@pytest.mark.describe( - "RegistrationStage - .run_op() -- Called with register request operation eligible for timeout" -) -class TestRegistrationStageWithTimeoutOfRegisterOperation( - StageRunOpTestBase, RegistrationStageConfig -): - @pytest.fixture - def op(self, stage, mocker): - op = pipeline_ops_provisioning.RegisterOperation( - " ", fake_registration_id, callback=mocker.MagicMock() - ) - yield op - - # Clean up any timers set on it - if op.provisioning_timeout_timer: - op.provisioning_timeout_timer.cancel() - if op.retry_after_timer: - op.retry_after_timer.cancel() - if op.provisioning_timeout_timer: - op.provisioning_timeout_timer.cancel() - - @pytest.fixture - def mock_timer(self, 
mocker): - return mocker.patch( - "azure.iot.device.provisioning.pipeline.pipeline_stages_provisioning.Timer" - ) - - @pytest.mark.it( - "Adds a provisioning timeout timer with the interval specified in the configuration to the operation, and starts it" - ) - def test_adds_timer(self, mocker, stage, op, mock_timer): - stage.run_op(op) - - assert mock_timer.call_count == 1 - assert mock_timer.call_args == mocker.call(constant.DEFAULT_TIMEOUT_INTERVAL, mocker.ANY) - assert op.provisioning_timeout_timer is mock_timer.return_value - assert op.provisioning_timeout_timer.start.call_count == 1 - assert op.provisioning_timeout_timer.start.call_args == mocker.call() - - @pytest.mark.it( - "Sends converted RequestResponse Op down the pipeline after attaching timer to the original op" - ) - def test_sends_down(self, mocker, stage, op, mock_timer): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_base.RequestAndResponseOperation) - - assert op.provisioning_timeout_timer is mock_timer.return_value - - @pytest.mark.it("Completes the operation unsuccessfully, with a ServiceError due to timeout") - def test_not_complete_timeout(self, mocker, stage, op, mock_timer): - # Apply the timer - stage.run_op(op) - assert not op.completed - assert mock_timer.call_count == 1 - on_timer_complete = mock_timer.call_args[0][1] - - # Call timer complete callback (indicating timer completion) - on_timer_complete() - - # Op is now completed with error - assert op.completed - assert isinstance(op.error, exceptions.ServiceError) - assert "register" in op.error.args[0] - - @pytest.mark.it( - "Completes the operation successfully, cancels and clears the operation's timeout timer" - ) - def test_complete_before_timeout(self, mocker, stage, op, mock_timer): - # Apply the timer - stage.run_op(op) - assert not op.completed - assert mock_timer.call_count == 1 - mock_timer_inst = 
op.provisioning_timeout_timer - assert mock_timer_inst is mock_timer.return_value - assert mock_timer_inst.cancel.call_count == 0 - - # Complete the next operation - new_op = stage.send_op_down.call_args[0][0] - new_op.status_code = 200 - new_op.response_body = "{}".encode("utf-8") - new_op.complete() - - # Timer is now cancelled and cleared - assert mock_timer_inst.cancel.call_count == 1 - assert mock_timer_inst.cancel.call_args == mocker.call() - assert op.provisioning_timeout_timer is None - - -class PollingStageConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_provisioning.PollingStatusStage - - @pytest.fixture - def init_kwargs(self): - return {} - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_provisioning.PollingStatusStage, - stage_test_config_class=PollingStageConfig, -) - - -@pytest.mark.describe("PollingStatusStage - .run_op() -- called with PollStatusOperation") -class TestPollingStatusStageWithPollStatusOperation(StageRunOpTestBase, PollingStageConfig): - @pytest.fixture - def op(self, stage, mocker): - op = pipeline_ops_provisioning.PollStatusOperation( - fake_operation_id, " ", callback=mocker.MagicMock() - ) - yield op - - # Clean up any timers set on it - if op.polling_timer: - op.polling_timer.cancel() - if op.retry_after_timer: - op.retry_after_timer.cancel() - if op.provisioning_timeout_timer: - op.provisioning_timeout_timer.cancel() - - @pytest.mark.it( - "Sends a new RequestAndResponseOperation down the pipeline, configured to request a registration from provisioning service" - ) - def test_request_and_response_op(self, stage, op): - stage.run_op(op) - - assert 
stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_base.RequestAndResponseOperation) - assert new_op.request_type == "query" - assert new_op.method == "GET" - assert new_op.resource_location == "/" - assert new_op.request_body == " " - - -@pytest.mark.describe("PollingStatusStage - .run_op() -- Called with other arbitrary operation") -class TestPollingStatusStageWithArbitraryOperation(StageRunOpTestBase, PollingStageConfig): - @pytest.fixture - def op(self, arbitrary_op): - return arbitrary_op - - @pytest.mark.it("Sends the operation down the pipeline") - def test_sends_op_down(self, mocker, stage, op): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - -@pytest.mark.describe( - "PollingStatusStage - OCCURRENCE: RequestAndResponseOperation created from PollStatusOperation is completed" -) -class TestPollingStatusStageWithPollStatusOperationCompleted(PollingStageConfig): - @pytest.fixture - def send_query_op(self, mocker): - op = pipeline_ops_provisioning.PollStatusOperation( - fake_operation_id, " ", callback=mocker.MagicMock() - ) - yield op - - # Clean up any timers set on it - if op.polling_timer: - op.polling_timer.cancel() - if op.retry_after_timer: - op.retry_after_timer.cancel() - if op.provisioning_timeout_timer: - op.provisioning_timeout_timer.cancel() - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, send_query_op): - stage = cls_type(**init_kwargs) - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - # Run the registration operation - stage.run_op(send_query_op) - return stage - - @pytest.fixture - def request_and_response_op(self, stage): - assert stage.send_op_down.call_count == 1 - op = stage.send_op_down.call_args[0][0] - assert isinstance(op, pipeline_ops_base.RequestAndResponseOperation) - # reset the 
stage mock for convenience - stage.send_op_down.reset_mock() - return op - - @pytest.mark.it( - "Completes the PollStatusOperation unsuccessfully, with the error from the RequestAndResponseOperation, if the RequestAndResponseOperation is completed unsuccessfully" - ) - @pytest.mark.parametrize( - "status_code", - [ - pytest.param(None, id="Status Code: None"), - pytest.param(200, id="Status Code: 200"), - pytest.param(300, id="Status Code: 300"), - pytest.param(400, id="Status Code: 400"), - pytest.param(500, id="Status Code: 500"), - ], - ) - @pytest.mark.parametrize( - "has_response_body", [True, False], ids=["With Response Body", "No Response Body"] - ) - def test_request_and_response_op_completed_with_err( - self, - stage, - send_query_op, - request_and_response_op, - status_code, - has_response_body, - arbitrary_exception, - ): - assert not send_query_op.completed - assert not request_and_response_op.completed - - # NOTE: It shouldn't happen that an operation completed with error has a status code or a - # response body, but it IS possible. 
- request_and_response_op.status_code = status_code - if has_response_body: - request_and_response_op.response_body = b'{"key": "value"}' - request_and_response_op.complete(error=arbitrary_exception) - - assert request_and_response_op.completed - assert request_and_response_op.error is arbitrary_exception - assert send_query_op.completed - assert send_query_op.error is arbitrary_exception - assert send_query_op.registration_result is None - - @pytest.mark.it( - "Completes the PollStatusOperation unsuccessfully with a ServiceError if the RequestAndResponseOperation is completed with a status code >= 300 and less than 429" - ) - @pytest.mark.parametrize( - "has_response_body", [True, False], ids=["With Response Body", "No Response Body"] - ) - @pytest.mark.parametrize( - "status_code", - [ - pytest.param(300, id="Status Code: 300"), - pytest.param(400, id="Status Code: 400"), - pytest.param(428, id="Status Code: 428"), - ], - ) - def test_request_and_response_op_completed_success_with_bad_code( - self, stage, send_query_op, request_and_response_op, status_code, has_response_body - ): - assert not send_query_op.completed - assert not request_and_response_op.completed - - request_and_response_op.status_code = status_code - if has_response_body: - request_and_response_op.response_body = b'{"key": "value"}' - request_and_response_op.complete() - - assert request_and_response_op.completed - assert request_and_response_op.error is None - assert send_query_op.completed - assert isinstance(send_query_op.error, ServiceError) - # Twin is NOT returned - assert send_query_op.registration_result is None - - @pytest.mark.it( - "Decodes, deserializes, and returns registration_result on the PollStatusOperation op when RequestAndResponseOperation completes with no error if the status code < 300 and if status is 'assigned'" - ) - def test_request_and_response_op_completed_success_with_status_assigned( - self, stage, send_query_op, request_and_response_op - ): - registration_result = 
create_registration_result(" ", "assigned") - - assert not send_query_op.completed - assert not request_and_response_op.completed - - request_and_response_op.status_code = 200 - request_and_response_op.retry_after = None - request_and_response_op.response_body = get_registration_result_as_bytes( - registration_result - ) - request_and_response_op.complete() - - assert request_and_response_op.completed - assert request_and_response_op.error is None - assert send_query_op.completed - assert send_query_op.error is None - # We need to assert string representations as these are inherently different objects - assert str(send_query_op.registration_result) == str(registration_result) - - @pytest.mark.it( - "Decodes, deserializes, and returns registration_result along with an error on the PollStatusOperation op when RequestAndResponseOperation completes with status code < 300 and status 'failed'" - ) - def test_request_and_response_op_completed_success_with_status_failed( - self, stage, send_query_op, request_and_response_op - ): - registration_result = create_registration_result(" ", "failed") - - assert not send_query_op.completed - assert not request_and_response_op.completed - - request_and_response_op.status_code = 200 - request_and_response_op.retry_after = None - request_and_response_op.response_body = get_registration_result_as_bytes( - registration_result - ) - request_and_response_op.complete() - - assert request_and_response_op.completed - assert request_and_response_op.error is None - assert send_query_op.completed - assert isinstance(send_query_op.error, ServiceError) - # We need to assert string representations as these are inherently different objects - assert str(send_query_op.registration_result) == str(registration_result) - assert "failed registration status" in str(send_query_op.error) - - @pytest.mark.it( - "Returns error on the PollStatusOperation op when RequestAndResponseOperation completes with status code < 300 and some unknown status" - ) - def 
test_request_and_response_op_completed_success_with_unknown_status( - self, stage, send_query_op, request_and_response_op - ): - registration_result = create_registration_result(" ", "some_status") - - assert not send_query_op.completed - assert not request_and_response_op.completed - - request_and_response_op.status_code = 200 - request_and_response_op.retry_after = None - request_and_response_op.response_body = get_registration_result_as_bytes( - registration_result - ) - request_and_response_op.complete() - - assert request_and_response_op.completed - assert request_and_response_op.error is None - assert send_query_op.completed - assert isinstance(send_query_op.error, ServiceError) - assert "invalid registration status" in str(send_query_op.error) - - -@pytest.mark.describe("PollingStatusStage - .run_op() -- retried again with PollStatusOperation") -class TestPollingStatusStageWithPollStatusRetryOperation(RetryStageConfig): - @pytest.fixture - def cls_type(self): - return pipeline_stages_provisioning.PollingStatusStage - - @pytest.fixture - def op(self, stage, mocker): - op = pipeline_ops_provisioning.PollStatusOperation( - fake_operation_id, " ", callback=mocker.MagicMock() - ) - yield op - - # Clean up any timers set on it - if op.polling_timer: - op.polling_timer.cancel() - if op.retry_after_timer: - op.retry_after_timer.cancel() - if op.provisioning_timeout_timer: - op.provisioning_timeout_timer.cancel() - - @pytest.mark.it( - "Decodes, deserializes the response from RequestAndResponseOperation and retries the op if the status code > 429" - ) - def test_stage_retries_op_if_next_stage_responds_with_status_code_greater_than_429( - self, mocker, stage, op - ): - mock_timer = mocker.patch( - "azure.iot.device.provisioning.pipeline.pipeline_stages_provisioning.Timer" - ) - - stage.run_op(op) - assert stage.send_op_down.call_count == 1 - next_op = stage.send_op_down.call_args[0][0] - assert isinstance(next_op, pipeline_ops_base.RequestAndResponseOperation) - - 
next_op.status_code = 430 - next_op.retry_after = "1" - registration_result = create_registration_result(" ", "some_status") - next_op.response_body = get_registration_result_as_bytes(registration_result) - next_op.complete() - - assert op.retry_after_timer is not None - assert op.polling_timer is None - timer_callback = mock_timer.call_args[0][1] - timer_callback() - - assert stage.run_op.call_count == 2 - assert stage.send_op_down.call_count == 2 - - next_op_2 = stage.send_op_down.call_args[0][0] - assert isinstance(next_op_2, pipeline_ops_base.RequestAndResponseOperation) - assert next_op_2.request_type == "query" - assert next_op_2.method == "GET" - assert next_op_2.resource_location == "/" - assert next_op_2.request_body == " " - - @pytest.mark.it( - "Decodes, deserializes the response from RequestAndResponseOperation and retries the op if the status code < 300 and if status is 'assigning'" - ) - def test_stage_retries_op_if_next_stage_responds_with_status_assigning(self, mocker, stage, op): - mock_timer = mocker.patch( - "azure.iot.device.provisioning.pipeline.pipeline_stages_provisioning.Timer" - ) - - stage.run_op(op) - assert stage.send_op_down.call_count == 1 - next_op = stage.send_op_down.call_args[0][0] - assert isinstance(next_op, pipeline_ops_base.RequestAndResponseOperation) - - next_op.status_code = 228 - next_op.retry_after = "1" - registration_result = create_registration_result(" ", "assigning") - next_op.response_body = get_registration_result_as_bytes(registration_result) - next_op.complete() - - assert op.retry_after_timer is None - assert op.polling_timer is not None - timer_callback = mock_timer.call_args[0][1] - timer_callback() - - assert stage.run_op.call_count == 2 - assert stage.send_op_down.call_count == 2 - - next_op_2 = stage.send_op_down.call_args[0][0] - assert isinstance(next_op_2, pipeline_ops_base.RequestAndResponseOperation) - assert next_op_2.request_type == "query" - assert next_op_2.method == "GET" - assert 
next_op_2.resource_location == "/" - assert next_op_2.request_body == " " - - -@pytest.mark.describe( - "RegistrationStage - .run_op() -- Called with register request operation eligible for timeout" -) -class TestPollingStageWithTimeoutOfQueryOperation(StageRunOpTestBase, PollingStageConfig): - @pytest.fixture - def op(self, stage, mocker): - op = pipeline_ops_provisioning.PollStatusOperation( - fake_operation_id, " ", callback=mocker.MagicMock() - ) - yield op - - # Clean up any timers set on it - if op.polling_timer: - op.polling_timer.cancel() - if op.retry_after_timer: - op.retry_after_timer.cancel() - if op.provisioning_timeout_timer: - op.provisioning_timeout_timer.cancel() - - @pytest.fixture - def mock_timer(self, mocker): - return mocker.patch( - "azure.iot.device.provisioning.pipeline.pipeline_stages_provisioning.Timer" - ) - - @pytest.mark.it( - "Adds a provisioning timeout timer with the interval specified in the configuration to the operation, and starts it" - ) - def test_adds_timer(self, mocker, stage, op, mock_timer): - stage.run_op(op) - - assert mock_timer.call_count == 1 - assert mock_timer.call_args == mocker.call(constant.DEFAULT_TIMEOUT_INTERVAL, mocker.ANY) - assert op.provisioning_timeout_timer is mock_timer.return_value - assert op.provisioning_timeout_timer.start.call_count == 1 - assert op.provisioning_timeout_timer.start.call_args == mocker.call() - - @pytest.mark.it( - "Sends converted RequestResponse Op down the pipeline after attaching timer to the original op" - ) - def test_sends_down(self, mocker, stage, op, mock_timer): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_base.RequestAndResponseOperation) - - assert op.provisioning_timeout_timer is mock_timer.return_value - - @pytest.mark.it("Completes the operation unsuccessfully, with a ServiceError due to timeout") - def test_not_complete_timeout(self, mocker, stage, op, 
mock_timer): - # Apply the timer - stage.run_op(op) - assert not op.completed - assert mock_timer.call_count == 1 - on_timer_complete = mock_timer.call_args[0][1] - - # Call timer complete callback (indicating timer completion) - on_timer_complete() - - # Op is now completed with error - assert op.completed - assert isinstance(op.error, exceptions.ServiceError) - assert "query" in op.error.args[0] - - @pytest.mark.it( - "Completes the operation successfully, cancels and clears the operation's timeout timer" - ) - def test_complete_before_timeout(self, mocker, stage, op, mock_timer): - # Apply the timer - stage.run_op(op) - assert not op.completed - assert mock_timer.call_count == 1 - mock_timer_inst = op.provisioning_timeout_timer - assert mock_timer_inst is mock_timer.return_value - assert mock_timer_inst.cancel.call_count == 0 - - # Complete the next operation - new_op = stage.send_op_down.call_args[0][0] - new_op.status_code = 200 - new_op.response_body = "{}".encode("utf-8") - new_op.complete() - - # Timer is now cancelled and cleared - assert mock_timer_inst.cancel.call_count == 1 - assert mock_timer_inst.cancel.call_args == mocker.call() - assert op.provisioning_timeout_timer is None diff --git a/tests/unit/provisioning/pipeline/test_pipeline_stages_provisioning_mqtt.py b/tests/unit/provisioning/pipeline/test_pipeline_stages_provisioning_mqtt.py deleted file mode 100644 index 99c1d98cd..000000000 --- a/tests/unit/provisioning/pipeline/test_pipeline_stages_provisioning_mqtt.py +++ /dev/null @@ -1,479 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import logging -import pytest -import sys -import urllib -from azure.iot.device import constant as pkg_constant -from azure.iot.device.common.pipeline import ( - pipeline_ops_base, - pipeline_ops_mqtt, - pipeline_events_mqtt, - pipeline_events_base, - pipeline_exceptions, -) -from azure.iot.device.provisioning.pipeline import ( - config, - pipeline_stages_provisioning_mqtt, -) -from azure.iot.device.provisioning.pipeline import constant as pipeline_constant -from azure.iot.device import user_agent -from tests.unit.common.pipeline.helpers import StageRunOpTestBase, StageHandlePipelineEventTestBase -from tests.unit.common.pipeline import pipeline_stage_test - -logging.basicConfig(level=logging.DEBUG) -this_module = sys.modules[__name__] -pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") - - -@pytest.fixture(params=[True, False], ids=["With error", "No error"]) -def op_error(request, arbitrary_exception): - if request.param: - return arbitrary_exception - else: - return None - - -@pytest.fixture -def mock_mqtt_topic(mocker): - m = mocker.patch( - "azure.iot.device.provisioning.pipeline.pipeline_stages_provisioning_mqtt.mqtt_topic_provisioning" - ) - return m - - -class ProvisioningMQTTTranslationStageTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_provisioning_mqtt.ProvisioningMQTTTranslationStage - - @pytest.fixture - def init_kwargs(self): - return {} - - @pytest.fixture - def pipeline_config(self, mocker): - # auth type shouldn't matter for this stage, so just give it a fake sastoken for now. 
- cfg = config.ProvisioningPipelineConfig( - hostname="http://my.hostname", - registration_id="fake_reg_id", - id_scope="fake_id_scope", - sastoken=mocker.MagicMock(), - ) - return cfg - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs, nucleus, pipeline_config): - stage = cls_type(**init_kwargs) - stage.nucleus = nucleus - stage.nucleus.pipeline_configuration = pipeline_config - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - mocker.spy(stage, "report_background_exception") - return stage - - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_provisioning_mqtt.ProvisioningMQTTTranslationStage, - stage_test_config_class=ProvisioningMQTTTranslationStageTestConfig, -) - - -@pytest.mark.describe( - "ProvisioningMQTTTranslationStage - .run_op() -- Called with InitializePipelineOperation" -) -class TestProvisioningMQTTTranslationStageRunOpWithInitializePipelineOperation( - StageRunOpTestBase, ProvisioningMQTTTranslationStageTestConfig -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) - - @pytest.mark.it("Derives the MQTT client id, and sets it on the op") - def test_client_id(self, stage, op, pipeline_config): - assert not hasattr(op, "client_id") - stage.run_op(op) - - assert op.client_id == pipeline_config.registration_id - - @pytest.mark.it("Derives the MQTT username, and sets it on the op") - def test_username(self, stage, op, pipeline_config): - assert not hasattr(op, "username") - stage.run_op(op) - - expected_username = "{id_scope}/registrations/{registration_id}/api-version={api_version}&ClientVersion={user_agent}".format( - id_scope=pipeline_config.id_scope, - registration_id=pipeline_config.registration_id, - api_version=pkg_constant.PROVISIONING_API_VERSION, - user_agent=urllib.parse.quote(user_agent.get_provisioning_user_agent(), safe=""), - ) - assert 
op.username == expected_username - - @pytest.mark.it("Sends the op down the pipeline") - def test_sends_down(self, mocker, stage, op): - stage.run_op(op) - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - -@pytest.mark.describe( - "ProvisioningMQTTTranslationStage - .run_op() -- Called with RequestOperation (Register Request)" -) -class TestProvisioningMQTTTranslationStageRunOpWithRequestOperationRegister( - StageRunOpTestBase, ProvisioningMQTTTranslationStageTestConfig -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.RequestOperation( - request_type=pipeline_constant.REGISTER, - method="PUT", - resource_location="/", - request_body='{"json": "payload"}', - request_id="fake_request_id", - callback=mocker.MagicMock(), - ) - - @pytest.mark.it("Derives the Provisioning Register Request topic using the op's details") - def test_register_request_topic(self, mocker, stage, op, mock_mqtt_topic): - stage.run_op(op) - - assert mock_mqtt_topic.get_register_topic_for_publish.call_count == 1 - assert mock_mqtt_topic.get_register_topic_for_publish.call_args == mocker.call( - request_id=op.request_id - ) - - @pytest.mark.it( - "Sends a new MQTTPublishOperation down the pipeline with the original op's request body and the derived topic string" - ) - def test_sends_mqtt_publish_down(self, mocker, stage, op, mock_mqtt_topic): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_mqtt.MQTTPublishOperation) - assert new_op.topic == mock_mqtt_topic.get_register_topic_for_publish.return_value - assert new_op.payload == op.request_body - - @pytest.mark.it("Completes the original op upon completion of the new MQTTPublishOperation") - def test_complete_resulting_op(self, stage, op, op_error): - stage.run_op(op) - assert not op.completed - assert op.error is None - - assert stage.send_op_down.call_count == 1 - new_op 
= stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_mqtt.MQTTPublishOperation) - - new_op.complete(error=op_error) - - assert new_op.completed - assert new_op.error is op_error - assert op.completed - assert op.error is op_error - - -@pytest.mark.describe( - "ProvisioningMQTTTranslationStage - .run_op() -- Called with RequestOperation (Query Request)" -) -class TestProvisioningMQTTTranslationStageRunOpWithRequestOperationQuery( - StageRunOpTestBase, ProvisioningMQTTTranslationStageTestConfig -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.RequestOperation( - request_type=pipeline_constant.QUERY, - method="GET", - resource_location="/", - query_params={"operation_id": "fake_op_id"}, - request_body="some body", - request_id="fake_request_id", - callback=mocker.MagicMock(), - ) - - @pytest.mark.it("Derives the Provisioning Query Request topic using the op's details") - def test_register_request_topic(self, mocker, stage, op, mock_mqtt_topic): - stage.run_op(op) - - assert mock_mqtt_topic.get_query_topic_for_publish.call_count == 1 - assert mock_mqtt_topic.get_query_topic_for_publish.call_args == mocker.call( - request_id=op.request_id, operation_id=op.query_params["operation_id"] - ) - - @pytest.mark.it( - "Sends a new MQTTPublishOperation down the pipeline with the original op's request body and the derived topic string" - ) - def test_sends_mqtt_publish_down(self, mocker, stage, op, mock_mqtt_topic): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_mqtt.MQTTPublishOperation) - assert new_op.topic == mock_mqtt_topic.get_query_topic_for_publish.return_value - assert new_op.payload == op.request_body - - @pytest.mark.it("Completes the original op upon completion of the new MQTTPublishOperation") - def test_complete_resulting_op(self, stage, op, op_error): - stage.run_op(op) - assert not op.completed - assert op.error 
is None - - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_mqtt.MQTTPublishOperation) - - new_op.complete(error=op_error) - - assert new_op.completed - assert new_op.error is op_error - assert op.completed - assert op.error is op_error - - -@pytest.mark.describe( - "ProvisioningMQTTTranslationStage - .run_op() -- Called with RequestOperation (Unsupported Request Type)" -) -class TestProvisioningMQTTTranslationStageRunOpWithRequestOperationUnsupportedType( - StageRunOpTestBase, ProvisioningMQTTTranslationStageTestConfig -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.RequestOperation( - request_type="FAKE_REQUEST_TYPE", - method="GET", - resource_location="/", - request_body="some body", - request_id="fake_request_id", - callback=mocker.MagicMock(), - ) - - @pytest.mark.it("Completes the operation with an OperationError failure") - def test_fail(self, mocker, stage, op): - assert not op.completed - assert op.error is None - - stage.run_op(op) - - assert op.completed - assert isinstance(op.error, pipeline_exceptions.OperationError) - - -@pytest.mark.describe( - "ProvisioningMQTTTranslationStage - .run_op() -- Called with EnableFeatureOperation" -) -class TestProvisioningMQTTTranslationStageRunOpWithEnableFeatureOperation( - StageRunOpTestBase, ProvisioningMQTTTranslationStageTestConfig -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.EnableFeatureOperation( - feature_name=pipeline_constant.REGISTER, callback=mocker.MagicMock() - ) - - @pytest.mark.it( - "Sends a new MQTTSubscribeOperation down the pipeline, containing the subscription topic for Register, if Register is the feature being enabled" - ) - def test_mqtt_subscribe_sent_down(self, mocker, op, stage, mock_mqtt_topic): - stage.run_op(op) - - # Topic was derived as expected - assert mock_mqtt_topic.get_register_topic_for_subscribe.call_count == 1 - assert 
mock_mqtt_topic.get_register_topic_for_subscribe.call_args == mocker.call() - - # New op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_mqtt.MQTTSubscribeOperation) - - # New op has the expected topic - assert new_op.topic == mock_mqtt_topic.get_register_topic_for_subscribe.return_value - - @pytest.mark.it("Completes the original op upon completion of the new MQTTSubscribeOperation") - def test_complete_resulting_op(self, stage, op, op_error): - stage.run_op(op) - assert not op.completed - assert op.error is None - - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - - new_op.complete(error=op_error) - - assert new_op.completed - assert new_op.error is op_error - assert op.completed - assert op.error is op_error - - @pytest.mark.it( - "Completes the operation with an OperationError failure if the feature being enabled is of any type other than Register" - ) - def test_unsupported_feature(self, stage, op): - op.feature_name = "invalid feature" - assert not op.completed - assert op.error is None - - stage.run_op(op) - - assert op.completed - assert isinstance(op.error, pipeline_exceptions.OperationError) - - -@pytest.mark.describe( - "ProvisioningMQTTTranslationStage - .run_op() -- Called with DisableFeatureOperation" -) -class TestProvisioningMQTTTranslationStageRunOpWithDisableFeatureOperation( - StageRunOpTestBase, ProvisioningMQTTTranslationStageTestConfig -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.DisableFeatureOperation( - feature_name=pipeline_constant.REGISTER, callback=mocker.MagicMock() - ) - - @pytest.mark.it( - "Sends a new MQTTUnsubscribeOperation down the pipeline, containing the subscription topic for Register, if Register is the feature being disabled" - ) - def test_mqtt_unsubscribe_sent_down(self, mocker, op, stage, mock_mqtt_topic): - stage.run_op(op) - - # Topic was derived as 
expected - assert mock_mqtt_topic.get_register_topic_for_subscribe.call_count == 1 - assert mock_mqtt_topic.get_register_topic_for_subscribe.call_args == mocker.call() - - # New op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_mqtt.MQTTUnsubscribeOperation) - - # New op has the expected topic - assert new_op.topic == mock_mqtt_topic.get_register_topic_for_subscribe.return_value - - @pytest.mark.it("Completes the original op upon completion of the new MQTTUnsubscribeOperation") - def test_complete_resulting_op(self, stage, op, op_error): - stage.run_op(op) - assert not op.completed - assert op.error is None - - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - - new_op.complete(error=op_error) - - assert new_op.completed - assert new_op.error is op_error - assert op.completed - assert op.error is op_error - - @pytest.mark.it( - "Completes the operation with an OperationError failure if the feature being disabled is of any type other than Register" - ) - def test_unsupported_feature(self, stage, op): - op.feature_name = "invalid feature" - assert not op.completed - assert op.error is None - - stage.run_op(op) - - assert op.completed - assert isinstance(op.error, pipeline_exceptions.OperationError) - - -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .run_op() -- Called with other arbitrary operation" -) -class TestProvisioningMQTTTranslationStageRunOpWithArbitraryOperation( - StageRunOpTestBase, ProvisioningMQTTTranslationStageTestConfig -): - @pytest.fixture - def op(self, arbitrary_op): - return arbitrary_op - - @pytest.mark.it("Sends the operation down the pipeline") - def test_sends_op_down(self, mocker, stage, op): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - -@pytest.mark.it( - "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- 
Called with IncomingMQTTMessageEvent (DPS Response Topic)" -) -class TestProvisioningMQTTTranslationStageHandlePipelineEventWithIncomingMQTTMessageEventDPSResponseTopic( - StageHandlePipelineEventTestBase, ProvisioningMQTTTranslationStageTestConfig -): - @pytest.fixture - def status(self): - return 200 - - @pytest.fixture - def rid(self): - return "3226c2f7-3d30-425c-b83b-0c34335f8220" - - @pytest.fixture(params=["With retry-after", "No retry-after"]) - def retry_after(self, request): - if request.param == "With retry-after": - return "1234" - else: - return None - - @pytest.fixture - def event(self, status, rid, retry_after): - topic = "$dps/registrations/res/{status}/?$rid={rid}".format(status=status, rid=rid) - if retry_after: - topic = topic + "&retry-after={}".format(retry_after) - return pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic, payload=b"some payload") - - @pytest.mark.it( - "Sends a ResponseEvent up the pipeline containing the original event's payload and values extracted from the topic string" - ) - def test_response_event(self, event, stage, status, rid, retry_after): - stage.handle_pipeline_event(event) - - assert stage.send_event_up.call_count == 1 - new_event = stage.send_event_up.call_args[0][0] - assert isinstance(new_event, pipeline_events_base.ResponseEvent) - assert new_event.status_code == status - assert new_event.request_id == rid - assert new_event.retry_after == retry_after - assert new_event.response_body == event.payload - - -@pytest.mark.describe( - "ProvisioningMQTTTranslationStage - .handle_pipeline_event() -- Called with IncomingMQTTMessageEvent (Unrecognized topic string)" -) -class TestProvisioningMQTTTranslationStageHandlePipelineEventWithIncomingMQTTMessageEventUnknownTopicString( - StageHandlePipelineEventTestBase, ProvisioningMQTTTranslationStageTestConfig -): - @pytest.fixture - def event(self): - topic = "not a real topic" - return pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic, payload=b"some 
payload") - - @pytest.mark.it("Sends the event up the pipeline") - def test_sends_up(self, event, stage): - stage.handle_pipeline_event(event) - - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args[0][0] == event - - -@pytest.mark.describe( - "ProvisioningMQTTTranslationStage - .handle_pipeline_event() -- Called with other arbitrary event" -) -class TestProvisioningMQTTTranslationStageHandlePipelineEventWithArbitraryEvent( - StageHandlePipelineEventTestBase, ProvisioningMQTTTranslationStageTestConfig -): - @pytest.fixture - def event(self, arbitrary_event): - return arbitrary_event - - @pytest.mark.it("Sends the event up the pipeline") - def test_sends_up(self, event, stage): - stage.handle_pipeline_event(event) - - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args[0][0] == event diff --git a/tests/unit/provisioning/shared_client_fixtures.py b/tests/unit/provisioning/shared_client_fixtures.py deleted file mode 100644 index 9c53d5d18..000000000 --- a/tests/unit/provisioning/shared_client_fixtures.py +++ /dev/null @@ -1,69 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -"""This module contains test fixtures shared between sync/async client tests""" -import pytest -from azure.iot.device.provisioning.models.registration_result import ( - RegistrationResult, - RegistrationState, -) -from azure.iot.device.common.models.x509 import X509 - -"""Constants""" -fake_x509_cert_file_value = "fake_cert_file" -fake_x509_cert_key_file = "fake_cert_key_file" -fake_pass_phrase = "fake_pass_phrase" -fake_status = "200" -fake_sub_status = "OK" -fake_operation_id = "fake_operation_id" -fake_device_id = "MyDevice" -fake_assigned_hub = "MyIoTHub" - - -"""Pipeline fixtures""" - - -@pytest.fixture -def mock_pipeline_init(mocker): - return mocker.patch("azure.iot.device.provisioning.pipeline.MQTTPipeline") - - -@pytest.fixture(autouse=True) -def provisioning_pipeline(mocker): - return mocker.MagicMock(wraps=FakeProvisioningPipeline()) - - -class FakeProvisioningPipeline: - def __init__(self): - self.responses_enabled = {} - - def shutdown(self, callback): - callback() - - def connect(self, callback): - callback() - - def disconnect(self, callback): - callback() - - def enable_responses(self, callback): - callback() - - def register(self, payload, callback): - callback(result={}) - - -"""Parameter fixtures""" - - -@pytest.fixture -def registration_result(): - registration_state = RegistrationState(fake_device_id, fake_assigned_hub, fake_sub_status) - return RegistrationResult(fake_operation_id, fake_status, registration_state) - - -@pytest.fixture -def x509(): - return X509(fake_x509_cert_file_value, fake_x509_cert_key_file, fake_pass_phrase) diff --git a/tests/unit/provisioning/shared_client_tests.py b/tests/unit/provisioning/shared_client_tests.py deleted file mode 100644 index fb5f5a606..000000000 --- a/tests/unit/provisioning/shared_client_tests.py +++ /dev/null @@ -1,388 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) 
Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module contains tests that are shared between sync/async clients -i.e. tests for things defined in abstract clients""" - -import pytest -import logging - -from azure.iot.device.common import auth, handle_exceptions -from azure.iot.device.common.auth import sastoken as st -from azure.iot.device.provisioning.pipeline import ProvisioningPipelineConfig -from azure.iot.device import ProxyOptions -from azure.iot.device.common.pipeline.config import DEFAULT_KEEPALIVE - -logging.basicConfig(level=logging.DEBUG) - - -fake_provisioning_host = "hostname.com" -fake_registration_id = "MyRegId" -fake_id_scope = "Some0000Scope7898" -fake_symmetric_key = "Zm9vYmFy" - - -class SharedProvisioningClientInstantiationTests(object): - @pytest.mark.it( - "Stores the ProvisioningPipeline from the 'pipeline' parameter in the '_pipeline' attribute" - ) - def test_sets_provisioning_pipeline(self, client_class, provisioning_pipeline): - client = client_class(provisioning_pipeline) - - assert client._pipeline is provisioning_pipeline - - @pytest.mark.it( - "Sets the pipeline's `on_background_exception` handler to the `handle_background_exception` function from the `handle_exceptions` module" - ) - def test_pipeline_on_background_exception(self, client_class, provisioning_pipeline): - client = client_class(provisioning_pipeline) - assert ( - client._pipeline.on_background_exception - is handle_exceptions.handle_background_exception - ) - - @pytest.mark.it( - "Instantiates with the initial value of the '_provisioning_payload' attribute set to None" - ) - def test_payload(self, client_class, provisioning_pipeline): - client = client_class(provisioning_pipeline) - - assert client._provisioning_payload is None - - -class 
SharedProvisioningClientCreateMethodUserOptionTests(object): - @pytest.mark.it( - "Sets the 'server_verification_cert' user option parameter on the PipelineConfig, if provided" - ) - def test_server_verification_cert_option( - self, client_create_method, create_method_args, mock_pipeline_init - ): - server_verification_cert = "fake_server_verification_cert" - client_create_method(*create_method_args, server_verification_cert=server_verification_cert) - - # Get configuration object - assert mock_pipeline_init.call_count == 1 - config = mock_pipeline_init.call_args[0][0] - assert isinstance(config, ProvisioningPipelineConfig) - - assert config.server_verification_cert == server_verification_cert - - @pytest.mark.it( - "Sets the 'gateway_hostname' user option parameter on the PipelineConfig, if provided" - ) - def test_gateway_hostname_option( - self, client_create_method, create_method_args, mock_pipeline_init - ): - gateway_hostname = "my.gateway.hostname" - client_create_method(*create_method_args, gateway_hostname=gateway_hostname) - - # Get configuration object - assert mock_pipeline_init.call_count == 1 - config = mock_pipeline_init.call_args[0][0] - assert isinstance(config, ProvisioningPipelineConfig) - - assert config.gateway_hostname == gateway_hostname - - @pytest.mark.it( - "Sets the 'websockets' user option parameter on the PipelineConfig, if provided" - ) - def test_websockets_option( - self, mocker, client_create_method, create_method_args, mock_pipeline_init - ): - client_create_method(*create_method_args, websockets=True) - - # Get configuration object - assert mock_pipeline_init.call_count == 1 - config = mock_pipeline_init.call_args[0][0] - assert isinstance(config, ProvisioningPipelineConfig) - - assert config.websockets - - # TODO: Show that input in the wrong format is formatted to the correct one. This test exists - # in the ProvisioningPipelineConfig object already, but we do not currently show that this is felt - # from the API level. 
- @pytest.mark.it("Sets the 'cipher' user option parameter on the PipelineConfig, if provided") - def test_cipher_option(self, client_create_method, create_method_args, mock_pipeline_init): - - cipher = "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256" - client_create_method(*create_method_args, cipher=cipher) - - # Get configuration object - assert mock_pipeline_init.call_count == 1 - config = mock_pipeline_init.call_args[0][0] - assert isinstance(config, ProvisioningPipelineConfig) - - assert config.cipher == cipher - - @pytest.mark.it("Sets the 'proxy_options' user option parameter on the PipelineConfig") - def test_proxy_options(self, client_create_method, create_method_args, mock_pipeline_init): - proxy_options = ProxyOptions(proxy_type="HTTP", proxy_addr="127.0.0.1", proxy_port=8888) - client_create_method(*create_method_args, proxy_options=proxy_options) - - # Get configuration object - assert mock_pipeline_init.call_count == 1 - config = mock_pipeline_init.call_args[0][0] - assert isinstance(config, ProvisioningPipelineConfig) - - assert config.proxy_options is proxy_options - - @pytest.mark.it( - "Sets the 'keep_alive' user option parameter on the PipelineConfig, if provided" - ) - def test_keep_alive_options(self, client_create_method, create_method_args, mock_pipeline_init): - keepalive_value = 60 - client_create_method(*create_method_args, keep_alive=keepalive_value) - - # Get configuration object, and ensure it was used for both protocol pipelines - assert mock_pipeline_init.call_count == 1 - config = mock_pipeline_init.call_args[0][0] - assert isinstance(config, ProvisioningPipelineConfig) - - assert config.keep_alive == keepalive_value - - @pytest.mark.it("Raises a TypeError if an invalid user option parameter is provided") - def test_invalid_option( - self, mocker, client_create_method, create_method_args, mock_pipeline_init - ): - with pytest.raises(TypeError): - client_create_method(*create_method_args, 
invalid_option="some_value") - - @pytest.mark.it("Sets default user options if none are provided") - def test_default_options( - self, mocker, client_create_method, create_method_args, mock_pipeline_init - ): - client_create_method(*create_method_args) - - # Pipeline uses a ProvisioningPipelineConfig - assert mock_pipeline_init.call_count == 1 - config = mock_pipeline_init.call_args[0][0] - assert isinstance(config, ProvisioningPipelineConfig) - - # ProvisioningPipelineConfig has default options set that were not user-specified - assert config.server_verification_cert is None - assert config.gateway_hostname is None - assert config.websockets is False - assert config.cipher == "" - assert config.proxy_options is None - assert config.keep_alive == DEFAULT_KEEPALIVE - - -@pytest.mark.usefixtures("mock_pipeline_init") -class SharedProvisioningClientCreateFromSymmetricKeyTests( - SharedProvisioningClientCreateMethodUserOptionTests -): - @pytest.fixture - def client_create_method(self, client_class): - return client_class.create_from_symmetric_key - - @pytest.fixture - def create_method_args(self): - return [fake_provisioning_host, fake_registration_id, fake_id_scope, fake_symmetric_key] - - @pytest.mark.it( - "Creates a SasToken that uses a SymmetricKeySigningMechanism, from the values provided in parameters" - ) - def test_sastoken(self, mocker, client_class): - sksm_mock = mocker.patch.object(auth, "SymmetricKeySigningMechanism") - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - expected_uri = "{id_scope}/registrations/{registration_id}".format( - id_scope=fake_id_scope, registration_id=fake_registration_id - ) - - custom_ttl = 1000 - client_class.create_from_symmetric_key( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - symmetric_key=fake_symmetric_key, - sastoken_ttl=custom_ttl, - ) - - # SymmetricKeySigningMechanism created using the provided symmetric key - assert sksm_mock.call_count 
== 1 - assert sksm_mock.call_args == mocker.call(key=fake_symmetric_key) - - # SasToken created with the SymmetricKeySigningMechanism, the expected URI, and the custom ttl - assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call( - expected_uri, sksm_mock.return_value, ttl=custom_ttl - ) - - @pytest.mark.it( - "Uses 3600 seconds (1 hour) as the default SasToken TTL if no custom TTL is provided" - ) - def test_sastoken_default(self, mocker, client_class): - sksm_mock = mocker.patch.object(auth, "SymmetricKeySigningMechanism") - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - expected_uri = "{id_scope}/registrations/{registration_id}".format( - id_scope=fake_id_scope, registration_id=fake_registration_id - ) - - client_class.create_from_symmetric_key( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - symmetric_key=fake_symmetric_key, - ) - - # SymmetricKeySigningMechanism created using the provided symmetric key - assert sksm_mock.call_count == 1 - assert sksm_mock.call_args == mocker.call(key=fake_symmetric_key) - - # SasToken created with the SymmetricKeySigningMechanism, the expected URI, and the default ttl - assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call( - expected_uri, sksm_mock.return_value, ttl=3600 - ) - - @pytest.mark.it( - "Creates an MQTT pipeline with a ProvisioningPipelineConfig object containing the SasToken and values provided in the parameters" - ) - def test_pipeline_config(self, mocker, client_class, mock_pipeline_init): - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - - client_class.create_from_symmetric_key( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - symmetric_key=fake_symmetric_key, - ) - - # Verify pipeline was created with a ProvisioningPipelineConfig - assert mock_pipeline_init.call_count == 1 - assert 
isinstance(mock_pipeline_init.call_args[0][0], ProvisioningPipelineConfig) - - # Verify the ProvisioningPipelineConfig is constructed as expected - config = mock_pipeline_init.call_args[0][0] - assert config.hostname == fake_provisioning_host - assert config.gateway_hostname is None - assert config.registration_id == fake_registration_id - assert config.id_scope == fake_id_scope - assert config.sastoken is sastoken_mock.return_value - - @pytest.mark.it( - "Returns an instance of a ProvisioningDeviceClient using the created MQTT pipeline" - ) - def test_client_returned(self, mocker, client_class, mock_pipeline_init): - client = client_class.create_from_symmetric_key( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - symmetric_key=fake_symmetric_key, - ) - assert isinstance(client, client_class) - assert client._pipeline is mock_pipeline_init.return_value - - @pytest.mark.it("Raises ValueError if a SasToken creation results in failure") - def test_sastoken_failure(self, mocker, client_class): - sastoken_mock = mocker.patch.object(st, "RenewableSasToken") - token_err = st.SasTokenError("Some SasToken failure") - sastoken_mock.side_effect = token_err - - with pytest.raises(ValueError) as e_info: - client_class.create_from_symmetric_key( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - symmetric_key=fake_symmetric_key, - ) - assert e_info.value.__cause__ is token_err - - @pytest.mark.parametrize( - "registration_id", - [ - pytest.param(None, id="No Registration Id provided"), - pytest.param(" ", id="Blank Registration Id provided"), - pytest.param("", id="Empty Registration Id provided"), - ], - ) - @pytest.mark.it("Raises a ValueError if an invalid 'registration_id' parameter is provided") - def test_invalid_registration_id(self, client_class, registration_id): - with pytest.raises(ValueError): - client_class.create_from_symmetric_key( - 
provisioning_host=fake_provisioning_host, - registration_id=registration_id, - id_scope=fake_id_scope, - symmetric_key=fake_symmetric_key, - ) - - -@pytest.mark.usefixtures("mock_pipeline_init") -class SharedProvisioningClientCreateFromX509CertificateTests( - SharedProvisioningClientCreateMethodUserOptionTests -): - @pytest.fixture - def client_create_method(self, client_class): - return client_class.create_from_x509_certificate - - @pytest.fixture - def create_method_args(self, x509): - return [fake_provisioning_host, fake_registration_id, fake_id_scope, x509] - - @pytest.mark.it( - "Creates MQTT pipeline with a ProvisioningPipelineConfig object containing the X509 and other values provided in parameters" - ) - def test_pipeline_config(self, mocker, client_class, x509, mock_pipeline_init): - client_class.create_from_x509_certificate( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - x509=x509, - ) - - # Verify pipeline created with a ProvisioningPipelineConfig - assert mock_pipeline_init.call_count == 1 - assert isinstance(mock_pipeline_init.call_args[0][0], ProvisioningPipelineConfig) - - # Verify the ProvisioningPipelineConfig is constructed as expected - config = mock_pipeline_init.call_args[0][0] - assert config.hostname == fake_provisioning_host - assert config.gateway_hostname is None - assert config.registration_id == fake_registration_id - assert config.id_scope == fake_id_scope - assert config.x509 is x509 - - @pytest.mark.it( - "Returns an instance of a ProvisioningDeviceClient using the created MQTT pipeline" - ) - def test_client_returned(self, mocker, client_class, x509, mock_pipeline_init): - client = client_class.create_from_x509_certificate( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - x509=x509, - ) - - assert isinstance(client, client_class) - assert client._pipeline is mock_pipeline_init.return_value - - 
@pytest.mark.it("Raises a TypeError if the 'sastoken_ttl' kwarg is supplied by the user") - def test_sastoken_ttl(self, client_class, x509): - with pytest.raises(TypeError): - client_class.create_from_x509_certificate( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - x509=x509, - sastoken_ttl=1000, - ) - - @pytest.mark.parametrize( - "registration_id", - [ - pytest.param(None, id="No Registration Id provided"), - pytest.param(" ", id="Blank Registration Id provided"), - pytest.param("", id="Empty Registration Id provided"), - ], - ) - @pytest.mark.it("Raises a ValueError if an invalid 'registration_id' parameter is provided") - def test_invalid_registration_id(self, client_class, registration_id, x509): - with pytest.raises(ValueError): - client_class.create_from_x509_certificate( - provisioning_host=fake_provisioning_host, - registration_id=registration_id, - id_scope=fake_id_scope, - x509=x509, - ) diff --git a/tests/unit/provisioning/test_sync_provisioning_device_client.py b/tests/unit/provisioning/test_sync_provisioning_device_client.py deleted file mode 100644 index f80f9e5d6..000000000 --- a/tests/unit/provisioning/test_sync_provisioning_device_client.py +++ /dev/null @@ -1,346 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import pytest -import logging -from azure.iot.device.provisioning.provisioning_device_client import ProvisioningDeviceClient -from azure.iot.device.provisioning.pipeline import exceptions as pipeline_exceptions -from azure.iot.device import exceptions as client_exceptions -from .shared_client_tests import ( - SharedProvisioningClientInstantiationTests, - SharedProvisioningClientCreateFromSymmetricKeyTests, - SharedProvisioningClientCreateFromX509CertificateTests, -) - - -logging.basicConfig(level=logging.DEBUG) - - -class ProvisioningClientTestsConfig(object): - """Defines fixtures for synchronous ProvisioningDeviceClient tests""" - - @pytest.fixture - def client_class(self): - return ProvisioningDeviceClient - - @pytest.fixture - def client(self, provisioning_pipeline): - return ProvisioningDeviceClient(provisioning_pipeline) - - -@pytest.mark.describe("ProvisioningDeviceClient (Sync) - Instantiation") -class TestProvisioningClientInstantiation( - ProvisioningClientTestsConfig, SharedProvisioningClientInstantiationTests -): - pass - - -@pytest.mark.describe("ProvisioningDeviceClient (Sync) - .create_from_symmetric_key()") -class TestProvisioningClientCreateFromSymmetricKey( - ProvisioningClientTestsConfig, SharedProvisioningClientCreateFromSymmetricKeyTests -): - pass - - -@pytest.mark.describe("ProvisioningDeviceClient (Sync) - .create_from_x509_certificate()") -class TestProvisioningClientCreateFromX509Certificate( - ProvisioningClientTestsConfig, SharedProvisioningClientCreateFromX509CertificateTests -): - pass - - -@pytest.mark.describe("ProvisioningDeviceClient (Sync) - .register()") -class TestClientRegister(object): - @pytest.mark.it("Implicitly enables responses from provisioning service if not already enabled") - def test_enables_provisioning_only_if_not_already_enabled( - self, mocker, provisioning_pipeline, registration_result - ): - # Override callback to pass successful 
result - def register_complete_success_callback(payload, callback): - callback(result=registration_result) - - mocker.patch.object( - provisioning_pipeline, "register", side_effect=register_complete_success_callback - ) - - provisioning_pipeline.responses_enabled.__getitem__.return_value = False - - # assert provisioning_pipeline.responses_enabled is False - client = ProvisioningDeviceClient(provisioning_pipeline) - client.register() - - assert provisioning_pipeline.enable_responses.call_count == 1 - - provisioning_pipeline.enable_responses.reset_mock() - - provisioning_pipeline.responses_enabled.__getitem__.return_value = True - client.register() - assert provisioning_pipeline.enable_responses.call_count == 0 - - @pytest.mark.it("Begins a 'register' pipeline operation") - def test_register_calls_pipeline_register( - self, provisioning_pipeline, mocker, registration_result - ): - def register_complete_success_callback(payload, callback): - callback(result=registration_result) - - mocker.patch.object( - provisioning_pipeline, "register", side_effect=register_complete_success_callback - ) - client = ProvisioningDeviceClient(provisioning_pipeline) - client.register() - assert provisioning_pipeline.register.call_count == 1 - - @pytest.mark.it( - "Begins a 'shutdown' pipeline operation if the registration result is successful" - ) - def test_shutdown_upon_success(self, mocker, provisioning_pipeline, registration_result): - # success result - registration_result._status = "assigned" - - def register_complete_success_callback(payload, callback): - callback(result=registration_result) - - mocker.patch.object( - provisioning_pipeline, "register", side_effect=register_complete_success_callback - ) - - client = ProvisioningDeviceClient(provisioning_pipeline) - client.register() - - assert provisioning_pipeline.shutdown.call_count == 1 - - @pytest.mark.it( - "Does NOT begin a 'shutdown' pipeline operation if the registration result is NOT successful" - ) - def 
test_no_shutdown_upon_fail(self, mocker, provisioning_pipeline, registration_result): - # fail result - registration_result._status = "not assigned" - - def register_complete_fail_callback(payload, callback): - callback(result=registration_result) - - mocker.patch.object( - provisioning_pipeline, "register", side_effect=register_complete_fail_callback - ) - - client = ProvisioningDeviceClient(provisioning_pipeline) - client.register() - - assert provisioning_pipeline.shutdown.call_count == 0 - - @pytest.mark.it( - "Waits for the completion of both the 'register' and 'shutdown' pipeline operations before returning, if the registration result is successful" - ) - def test_waits_for_pipeline_op_completions_on_success( - self, mocker, provisioning_pipeline, registration_result - ): - # success result - registration_result._status = "assigned" - - # Set up mocks - cb_mock_register = mocker.MagicMock() - cb_mock_register.wait_for_completion.return_value = registration_result - cb_mock_shutdown = mocker.MagicMock() - mocker.patch( - "azure.iot.device.provisioning.provisioning_device_client.EventedCallback" - ).side_effect = [cb_mock_register, cb_mock_shutdown] - - # Run test - client = ProvisioningDeviceClient(provisioning_pipeline) - client.register() - - # Calls made as expected - assert provisioning_pipeline.register.call_count == 1 - assert provisioning_pipeline.shutdown.call_count == 1 - # Callbacks sent to pipeline as expected - assert provisioning_pipeline.register.call_args == mocker.call( - payload=mocker.ANY, callback=cb_mock_register - ) - assert provisioning_pipeline.shutdown.call_args == mocker.call(callback=cb_mock_shutdown) - # Callback completions were waited upon as expected - assert cb_mock_register.wait_for_completion.call_count == 1 - assert cb_mock_shutdown.wait_for_completion.call_count == 1 - - @pytest.mark.it( - "Waits for the completion of just the 'register' pipeline operation before returning, if the registration result is NOT successful" - ) - 
def test_waits_for_pipeline_op_completion_on_failure( - self, mocker, provisioning_pipeline, registration_result - ): - # fail result - registration_result._status = "not assigned" - - # Set up mocks - cb_mock_register = mocker.MagicMock() - cb_mock_register.wait_for_completion.return_value = registration_result - cb_mock_shutdown = mocker.MagicMock() - mocker.patch( - "azure.iot.device.provisioning.provisioning_device_client.EventedCallback" - ).side_effect = [cb_mock_register, cb_mock_shutdown] - - # Run test - client = ProvisioningDeviceClient(provisioning_pipeline) - client.register() - - # Calls made as expected - assert provisioning_pipeline.register.call_count == 1 - assert provisioning_pipeline.shutdown.call_count == 0 - # Callbacks sent to pipeline as expected - assert provisioning_pipeline.register.call_args == mocker.call( - payload=mocker.ANY, callback=cb_mock_register - ) - # Callback completions were waited upon as expected - assert cb_mock_register.wait_for_completion.call_count == 1 - assert cb_mock_shutdown.wait_for_completion.call_count == 0 - - @pytest.mark.it("Returns the registration result that the pipeline returned") - def test_verifies_registration_result_returned( - self, mocker, provisioning_pipeline, registration_result - ): - result = registration_result - - def register_complete_success_callback(payload, callback): - callback(result=result) - - mocker.patch.object( - provisioning_pipeline, "register", side_effect=register_complete_success_callback - ) - - client = ProvisioningDeviceClient(provisioning_pipeline) - result_returned = client.register() - assert result_returned == result - - @pytest.mark.it( - "Raises a client error if the `register` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - pytest.param( - pipeline_exceptions.ConnectionDroppedError, - client_exceptions.ConnectionDroppedError, - id="ConnectionDroppedError->ConnectionDroppedError", - ), - 
pytest.param( - pipeline_exceptions.ConnectionFailedError, - client_exceptions.ConnectionFailedError, - id="ConnectionFailedError->ConnectionFailedError", - ), - pytest.param( - pipeline_exceptions.UnauthorizedError, - client_exceptions.CredentialError, - id="UnauthorizedError->CredentialError", - ), - pytest.param( - pipeline_exceptions.ProtocolClientError, - client_exceptions.ClientError, - id="ProtocolClientError->ClientError", - ), - pytest.param( - pipeline_exceptions.OperationTimeout, - client_exceptions.OperationTimeout, - id="OperationTimeout->OperationTimeout", - ), - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), - ], - ) - def test_raises_error_on_register_pipeline_op_error( - self, mocker, pipeline_error, client_error, provisioning_pipeline - ): - error = pipeline_error() - - def register_complete_failure_callback(payload, callback): - callback(result=None, error=error) - - mocker.patch.object( - provisioning_pipeline, "register", side_effect=register_complete_failure_callback - ) - - client = ProvisioningDeviceClient(provisioning_pipeline) - with pytest.raises(client_error) as e_info: - client.register() - - assert e_info.value.__cause__ is error - assert provisioning_pipeline.register.call_count == 1 - - @pytest.mark.it( - "Raises a client error if the `shutdown` pipeline operation calls back with a pipeline error" - ) - @pytest.mark.parametrize( - "pipeline_error,client_error", - [ - # The only expected errors are unexpected ones - pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError") - ], - ) - def test_raises_error_on_shutdown_pipeline_op_error( - self, mocker, pipeline_error, client_error, provisioning_pipeline, registration_result - ): - # success result is required to trigger shutdown - registration_result._status = "assigned" - - error = pipeline_error() - - def register_complete_success_callback(payload, callback): - callback(result=registration_result) - - def 
shutdown_failure_callback(callback): - callback(result=None, error=error) - - mocker.patch.object( - provisioning_pipeline, "register", side_effect=register_complete_success_callback - ) - mocker.patch.object( - provisioning_pipeline, "shutdown", side_effect=shutdown_failure_callback - ) - - client = ProvisioningDeviceClient(provisioning_pipeline) - with pytest.raises(client_error) as e_info: - client.register() - - assert e_info.value.__cause__ is error - assert provisioning_pipeline.register.call_count == 1 - - -@pytest.mark.describe("ProvisioningDeviceClient (Sync) - .set_provisioning_payload()") -class TestClientProvisioningPayload(object): - @pytest.mark.it("Sets the payload on the provisioning payload attribute") - @pytest.mark.parametrize( - "payload_input", - [ - pytest.param("Hello World", id="String input"), - pytest.param(222, id="Integer input"), - pytest.param(object(), id="Object input"), - pytest.param(None, id="None input"), - pytest.param([1, "str"], id="List input"), - pytest.param({"a": 2}, id="Dictionary input"), - ], - ) - def test_set_payload(self, mocker, payload_input): - provisioning_pipeline = mocker.MagicMock() - - client = ProvisioningDeviceClient(provisioning_pipeline) - client.provisioning_payload = payload_input - assert client._provisioning_payload == payload_input - - @pytest.mark.it("Gets the payload from the provisioning payload property") - @pytest.mark.parametrize( - "payload_input", - [ - pytest.param("Hello World", id="String input"), - pytest.param(222, id="Integer input"), - pytest.param(object(), id="Object input"), - pytest.param(None, id="None input"), - pytest.param([1, "str"], id="List input"), - pytest.param({"a": 2}, id="Dictionary input"), - ], - ) - def test_get_payload(self, mocker, payload_input): - provisioning_pipeline = mocker.MagicMock() - - client = ProvisioningDeviceClient(provisioning_pipeline) - client.provisioning_payload = payload_input - assert client.provisioning_payload == payload_input diff --git 
a/tests/unit/simple_e2e.py b/tests/unit/simple_e2e.py new file mode 100644 index 000000000..d1f797a1a --- /dev/null +++ b/tests/unit/simple_e2e.py @@ -0,0 +1,216 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""This module is not currently run in the gate. It's just a simple local E2E test. +Will be removed once full E2E support implemented""" + +import asyncio +import logging +import os +import pytest +from dev_utils import iptables +from azure.iot.device import mqtt_client +from dev_utils import mqtt_helper + +logger = logging.getLogger(__name__) + +PORT = 8883 +TRANSPORT = "tcp" # websockets +IPTABLES_TRANSPORT = "mqtt" # mqttws + +CONNECTION_STRING = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") +HOSTNAME = mqtt_helper.get_hostname(CONNECTION_STRING) + + +class Dropper: + def __init__(self, transport): + self.transport = transport + + def disconnect_outgoing(self, disconnect_type): + iptables.disconnect_output_port(disconnect_type, self.transport, HOSTNAME) + + def drop_outgoing(self): + iptables.disconnect_output_port("DROP", self.transport, HOSTNAME) + + def reject_outgoing(self): + iptables.disconnect_output_port("REJECT", self.transport, HOSTNAME) + + def restore_all(self): + iptables.reconnect_all(self.transport, HOSTNAME) + + +@pytest.fixture(scope="function") +def dropper(): + dropper = Dropper(IPTABLES_TRANSPORT) + yield dropper + logger.info("restoring all") + dropper.restore_all() + + +@pytest.fixture +async def client(): + client_id = mqtt_helper.get_client_id(CONNECTION_STRING) + username = mqtt_helper.get_username(CONNECTION_STRING) + password = mqtt_helper.get_password(CONNECTION_STRING) + ssl_context = mqtt_helper.create_ssl_context() + + client = mqtt_client.MQTTClient( + 
client_id=client_id, + hostname=HOSTNAME, + port=PORT, + transport=TRANSPORT, + keep_alive=5, + auto_reconnect=False, + ssl_context=ssl_context, + ) + client.set_credentials(username, password) + yield client + await client.disconnect() + + +def assert_connected_state(client): + assert client.is_connected() + assert client._desire_connection + assert not client._network_loop.done() + + +def assert_disconnected_state(client): + assert not client.is_connected() + assert not client._desire_connection + assert client._network_loop is None + + +def assert_dropped_conn_state(client): + assert not client.is_connected() + assert client._desire_connection + assert client._network_loop.done() + + +@pytest.mark.it("Connect and disconnect") +async def test_connect_disconnect_twice(client): + async def conn_disconn(): + # Connect + await client.connect() + assert_connected_state(client) + # Wait + await asyncio.sleep(1) + # Disconnect + await client.disconnect() + assert_disconnected_state(client) + + # Do it twice to make sure it's repeatable + await conn_disconn() + await asyncio.sleep(1) + await conn_disconn() + + +@pytest.mark.it("Queued connects and disconnects") +async def test_queued_connects_and_disconnects(client): + # TODO: this may be unreliable - there's no guarantee that they will resolve in the desired + # order, and thus, the assertion at the end may end up being incorrect. + # This test likely ought to be redesigned. 
+ await asyncio.gather( + client.connect(), + client.disconnect(), + client.disconnect(), + client.connect(), + client.connect(), + client.disconnect(), + ) + assert_disconnected_state(client) + + +@pytest.mark.it("Connection drop") +async def test_connection_drop(client, dropper): + await client.connect() + assert_connected_state(client) + # Wait + await asyncio.sleep(1) + # Drop network + dropper.drop_outgoing() + # Wait for drop + async with client.disconnected_cond: + await client.disconnected_cond.wait() + await asyncio.sleep(0.1) + assert_dropped_conn_state(client) + + +@pytest.mark.it("Connect while connected") +async def test_connect_while_connected(client): + await client.connect() + assert_connected_state(client) + await client.connect() + assert_connected_state(client) + + +@pytest.mark.it("Disconnect while never connected") +async def test_disconnect_while_never_connected(client): + await client.disconnect() + assert_disconnected_state(client) + + +@pytest.mark.it("Disconnect while disconnected") +async def test_disconnect_while_disconnected(client): + # Connect first to disconnect to have been at one point connected + await client.connect() + assert_connected_state(client) + await client.disconnect() + assert_disconnected_state(client) + await client.disconnect() + assert_disconnected_state(client) + + +@pytest.mark.it("Disconnect after drop") +async def test_disconnect_after_drop(client, dropper): + await client.connect() + assert_connected_state(client) + # Wait + await asyncio.sleep(1) + # Drop network + dropper.drop_outgoing() + # Wait for drop + async with client.disconnected_cond: + await client.disconnected_cond.wait() + await asyncio.sleep(0.1) + assert_dropped_conn_state(client) + # Restore and manually disconnect + dropper.restore_all() + await client.disconnect() + assert_disconnected_state(client) + + +@pytest.mark.it("Connect after drop") +async def test_connect_after_drop(client, dropper): + await client.connect() + 
assert_connected_state(client) + # Wait + await asyncio.sleep(1) + # Drop network + dropper.drop_outgoing() + # Wait for drop + async with client.disconnected_cond: + await client.disconnected_cond.wait() + await asyncio.sleep(0.1) + assert_dropped_conn_state(client) + # Restore and connect manually + dropper.restore_all() + await client.connect() + assert_connected_state(client) + # Wait + await asyncio.sleep(1) + # Drop network again + dropper.drop_outgoing() + # Wait for drop + async with client.disconnected_cond: + await client.disconnected_cond.wait() + await asyncio.sleep(0.1) + assert_dropped_conn_state(client) + # Restore and manually disconnect + dropper.restore_all() + await client.disconnect() + assert_disconnected_state(client) + + +# TODO: auto reconnect diff --git a/tests/unit/common/conftest.py b/tests/unit/test_config.py similarity index 62% rename from tests/unit/common/conftest.py rename to tests/unit/test_config.py index 5fcb9d2b9..b76351ea2 100644 --- a/tests/unit/common/conftest.py +++ b/tests/unit/test_config.py @@ -4,9 +4,6 @@ # license information. # -------------------------------------------------------------------------- -import pytest - - -@pytest.fixture -def fake_return_arg_value(): - return "__fake_return_arg_value__" +"""This is a placeholder file. 
It will likely not be necessary if the config objects are converted +into TypeDicts as is planned.""" +# TODO: Remove if no longer necessary, or complete diff --git a/tests/unit/common/auth/test_connection_string.py b/tests/unit/test_connection_string.py similarity index 90% rename from tests/unit/common/auth/test_connection_string.py rename to tests/unit/test_connection_string.py index 1d6934826..99ce10281 100644 --- a/tests/unit/common/auth/test_connection_string.py +++ b/tests/unit/test_connection_string.py @@ -6,10 +6,12 @@ import pytest import logging -from azure.iot.device.common.auth.connection_string import ConnectionString +from azure.iot.device.connection_string import ConnectionString logging.basicConfig(level=logging.DEBUG) +# TODO: eliminate refernces to service connection string + @pytest.mark.describe("ConnectionString") class TestConnectionString(object): @@ -67,7 +69,7 @@ def test_instantiates_correctly_from_string(self, input_string): id="Duplicate key", ), pytest.param( - "HostName=my.host.name;DeviceId=my-device;ModuleId=my-module;SharedAccessKey=mykeyname;x509=True", + "HostName=my.host.name;DeviceId=my-device;ModuleId=my-module;SharedAccessKey=mykeyname;x509=true", id="Mixed authentication scheme", ), ], @@ -115,6 +117,18 @@ def test_indexing_key_raises_key_error_if_key_not_in_string(self): ) cs["SharedAccessSignature"] + @pytest.mark.it( + "Supports the 'in' operator for validating if a key is contained in the ConnectionString" + ) + def test_item_in_string(self): + cs = ConnectionString( + "HostName=my.host.name;SharedAccessKeyName=mykeyname;SharedAccessKey=Zm9vYmFy" + ) + assert "SharedAccessKey" in cs + assert "SharedAccessKeyName" in cs + assert "HostName" in cs + assert "FakeKeyNotInTheString" not in cs + @pytest.mark.describe("ConnectionString - .get()") class TestConnectionStringGet(object): diff --git a/tests/unit/iothub/test_edge_hsm.py b/tests/unit/test_edge_hsm.py similarity index 65% rename from tests/unit/iothub/test_edge_hsm.py 
rename to tests/unit/test_edge_hsm.py index 8472ad11f..daa9e6e9f 100644 --- a/tests/unit/iothub/test_edge_hsm.py +++ b/tests/unit/test_edge_hsm.py @@ -3,20 +3,16 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- - +import base64 +import json import pytest -import logging import requests -import json -import base64 -import urllib -from azure.iot.device.iothub.edge_hsm import IoTEdgeHsm, IoTEdgeError +import urllib.parse +from azure.iot.device.edge_hsm import IoTEdgeHsm +from azure.iot.device import exceptions as exc from azure.iot.device import user_agent -logging.basicConfig(level=logging.DEBUG) - - @pytest.fixture def edge_hsm(): return IoTEdgeHsm( @@ -108,76 +104,78 @@ def test_set_api_version(self): @pytest.mark.describe("IoTEdgeHsm - .get_certificate()") class TestIoTEdgeHsmGetCertificate(object): + @pytest.fixture(autouse=True) + def mock_requests_get(self, mocker): + return mocker.patch.object(requests, "get") + @pytest.mark.it("Sends an HTTP GET request to retrieve the trust bundle from Edge") - def test_requests_trust_bundle(self, mocker, edge_hsm): - mock_request_get = mocker.patch.object(requests, "get") + async def test_requests_trust_bundle(self, mocker, edge_hsm, mock_requests_get): expected_url = edge_hsm.workload_uri + "trust-bundle" expected_params = {"api-version": edge_hsm.api_version} expected_headers = { "User-Agent": urllib.parse.quote_plus(user_agent.get_iothub_user_agent()) } - edge_hsm.get_certificate() + await edge_hsm.get_certificate() - assert mock_request_get.call_count == 1 - assert mock_request_get.call_args == mocker.call( + assert mock_requests_get.call_count == 1 + assert mock_requests_get.call_args == mocker.call( expected_url, params=expected_params, headers=expected_headers ) @pytest.mark.it("Returns the certificate from the trust bundle received from Edge") - def test_returns_certificate(self, mocker, 
edge_hsm): - mock_request_get = mocker.patch.object(requests, "get") - mock_response = mock_request_get.return_value + async def test_returns_certificate(self, edge_hsm, mock_requests_get): + mock_response = mock_requests_get.return_value certificate = "my certificate" mock_response.json.return_value = {"certificate": certificate} - returned_cert = edge_hsm.get_certificate() + returned_cert = await edge_hsm.get_certificate() assert returned_cert is certificate @pytest.mark.it("Raises IoTEdgeError if a bad request is made to Edge") - def test_bad_request(self, mocker, edge_hsm): - mock_request_get = mocker.patch.object(requests, "get") - mock_response = mock_request_get.return_value + async def test_bad_request(self, edge_hsm, mock_requests_get): + mock_response = mock_requests_get.return_value error = requests.exceptions.HTTPError() mock_response.raise_for_status.side_effect = error - with pytest.raises(IoTEdgeError) as e_info: - edge_hsm.get_certificate() + with pytest.raises(exc.IoTEdgeError) as e_info: + await edge_hsm.get_certificate() assert e_info.value.__cause__ is error @pytest.mark.it("Raises IoTEdgeError if there is an error in json decoding the trust bundle") - def test_bad_json(self, mocker, edge_hsm): - mock_request_get = mocker.patch.object(requests, "get") - mock_response = mock_request_get.return_value + async def test_bad_json(self, edge_hsm, mock_requests_get): + mock_response = mock_requests_get.return_value error = ValueError() mock_response.json.side_effect = error - with pytest.raises(IoTEdgeError) as e_info: - edge_hsm.get_certificate() + with pytest.raises(exc.IoTEdgeError) as e_info: + await edge_hsm.get_certificate() assert e_info.value.__cause__ is error @pytest.mark.it("Raises IoTEdgeError if the certificate is missing from the trust bundle") - def test_bad_trust_bundle(self, mocker, edge_hsm): - mock_request_get = mocker.patch.object(requests, "get") - mock_response = mock_request_get.return_value + async def test_bad_trust_bundle(self, 
edge_hsm, mock_requests_get): + mock_response = mock_requests_get.return_value # Return an empty json dict with no 'certificate' key mock_response.json.return_value = {} - with pytest.raises(IoTEdgeError): - edge_hsm.get_certificate() + with pytest.raises(exc.IoTEdgeError): + await edge_hsm.get_certificate() @pytest.mark.describe("IoTEdgeHsm - .sign()") class TestIoTEdgeHsmSign(object): + @pytest.fixture(autouse=True) + def mock_requests_post(self, mocker): + return mocker.patch.object(requests, "post") + @pytest.mark.it( "Makes an HTTP request to Edge to sign a piece of string data using the HMAC-SHA256 algorithm" ) - def test_requests_data_signing(self, mocker, edge_hsm): + async def test_requests_data_signing(self, mocker, edge_hsm, mock_requests_post): data_str = "somedata" data_str_b64 = "c29tZWRhdGE=" - mock_request_post = mocker.patch.object(requests, "post") - mock_request_post.return_value.json.return_value = {"digest": "somedigest"} + mock_requests_post.return_value.json.return_value = {"digest": "somedigest"} expected_url = "{workload_uri}modules/{module_id}/genid/{generation_id}/sign".format( workload_uri=edge_hsm.workload_uri, module_id=edge_hsm.module_id, @@ -189,65 +187,77 @@ def test_requests_data_signing(self, mocker, edge_hsm): } expected_json = json.dumps({"keyId": "primary", "algo": "HMACSHA256", "data": data_str_b64}) - edge_hsm.sign(data_str) + await edge_hsm.sign(data_str) - assert mock_request_post.call_count == 1 - assert mock_request_post.call_args == mocker.call( + assert mock_requests_post.call_count == 1 + assert mock_requests_post.call_args == mocker.call( url=expected_url, params=expected_params, headers=expected_headers, data=expected_json ) @pytest.mark.it("Base64 encodes the string data in the request") - def test_b64_encodes_data(self, mocker, edge_hsm): + async def test_b64_encodes_data(self, edge_hsm, mock_requests_post): # This test is actually implicitly tested in the first test, but it's # important to have an explicit test 
for it since it's a requirement data_str = "somedata" data_str_b64 = base64.b64encode(data_str.encode("utf-8")).decode() - mock_request_post = mocker.patch.object(requests, "post") - mock_request_post.return_value.json.return_value = {"digest": "somedigest"} + mock_requests_post.return_value.json.return_value = {"digest": "somedigest"} - edge_hsm.sign(data_str) + await edge_hsm.sign(data_str) - sent_data = json.loads(mock_request_post.call_args[1]["data"])["data"] + sent_data = json.loads(mock_requests_post.call_args[1]["data"])["data"] assert data_str != data_str_b64 assert sent_data == data_str_b64 @pytest.mark.it("Returns the signed data received from Edge") - def test_returns_signed_data(self, mocker, edge_hsm): + async def test_returns_signed_data(self, edge_hsm, mock_requests_post): expected_digest = "somedigest" - mock_request_post = mocker.patch.object(requests, "post") - mock_request_post.return_value.json.return_value = {"digest": expected_digest} + mock_requests_post.return_value.json.return_value = {"digest": expected_digest} - signed_data = edge_hsm.sign("somedata") + signed_data = await edge_hsm.sign("somedata") assert signed_data == expected_digest + @pytest.mark.it("Supports data strings in both string and byte formats") + @pytest.mark.parametrize( + "data_string, expected_request_data", + [ + pytest.param("sign this message", "c2lnbiB0aGlzIG1lc3NhZ2U=", id="String"), + pytest.param(b"sign this message", "c2lnbiB0aGlzIG1lc3NhZ2U=", id="Bytes"), + ], + ) + async def test_supported_types( + self, edge_hsm, data_string, expected_request_data, mock_requests_post + ): + mock_requests_post.return_value.json.return_value = {"digest": "somedigest"} + await edge_hsm.sign(data_string) + sent_data = json.loads(mock_requests_post.call_args[1]["data"])["data"] + + assert sent_data == expected_request_data + @pytest.mark.it("Raises IoTEdgeError if a bad request is made to EdgeHub") - def test_bad_request(self, mocker, edge_hsm): - mock_request_post = 
mocker.patch.object(requests, "post") - mock_response = mock_request_post.return_value + async def test_bad_request(self, edge_hsm, mock_requests_post): + mock_response = mock_requests_post.return_value error = requests.exceptions.HTTPError() mock_response.raise_for_status.side_effect = error - with pytest.raises(IoTEdgeError) as e_info: - edge_hsm.sign("somedata") + with pytest.raises(exc.IoTEdgeError) as e_info: + await edge_hsm.sign("somedata") assert e_info.value.__cause__ is error @pytest.mark.it("Raises IoTEdgeError if there is an error in json decoding the signed response") - def test_bad_json(self, mocker, edge_hsm): - mock_request_post = mocker.patch.object(requests, "post") - mock_response = mock_request_post.return_value + async def test_bad_json(self, edge_hsm, mock_requests_post): + mock_response = mock_requests_post.return_value error = ValueError() mock_response.json.side_effect = error - with pytest.raises(IoTEdgeError) as e_info: - edge_hsm.sign("somedata") + with pytest.raises(exc.IoTEdgeError) as e_info: + await edge_hsm.sign("somedata") assert e_info.value.__cause__ is error @pytest.mark.it("Raises IoTEdgeError if the signed data is missing from the response") - def test_bad_response(self, mocker, edge_hsm): - mock_request_post = mocker.patch.object(requests, "post") - mock_response = mock_request_post.return_value + async def test_bad_response(self, edge_hsm, mock_requests_post): + mock_response = mock_requests_post.return_value mock_response.json.return_value = {} - with pytest.raises(IoTEdgeError): - edge_hsm.sign("somedata") + with pytest.raises(exc.IoTEdgeError): + await edge_hsm.sign("somedata") diff --git a/tests/unit/iothub/pipeline/test_http_path_iothub.py b/tests/unit/test_http_path_iothub.py similarity index 56% rename from tests/unit/iothub/pipeline/test_http_path_iothub.py rename to tests/unit/test_http_path_iothub.py index 25a2aa649..f7be0281c 100644 --- a/tests/unit/iothub/pipeline/test_http_path_iothub.py +++ 
b/tests/unit/test_http_path_iothub.py @@ -5,7 +5,7 @@ # -------------------------------------------------------------------------- import pytest import logging -from azure.iot.device.iothub.pipeline import http_path_iothub +from azure.iot.device import http_path_iothub logging.basicConfig(level=logging.DEBUG) @@ -14,75 +14,77 @@ # make sure any URL encoded value can encode a '+' specifically, in addition to regular encoding. -@pytest.mark.describe(".get_method_invoke_path()") +@pytest.mark.describe(".get_direct_method_invoke_path()") class TestGetMethodInvokePath(object): - @pytest.mark.it("Returns the method invoke HTTP path") + @pytest.mark.it("Returns the relative method invoke HTTP path") @pytest.mark.parametrize( "device_id, module_id, expected_path", [ pytest.param( "my_device", None, - "twins/my_device/methods", - id="'my_device' ==> 'twins/my_device/methods'", + "/twins/my_device/methods", + id="'my_device' ==> '/twins/my_device/methods'", ), pytest.param( "my/device", None, - "twins/my%2Fdevice/methods", - id="'my/device' ==> 'twins/my%2Fdevice/methods'", + "/twins/my%2Fdevice/methods", + id="'my/device' ==> '/twins/my%2Fdevice/methods'", ), pytest.param( "my+device", None, - "twins/my%2Bdevice/methods", - id="'my+device' ==> 'twins/my%2Bdevice/methods'", + "/twins/my%2Bdevice/methods", + id="'my+device' ==> '/twins/my%2Bdevice/methods'", ), pytest.param( "my_device", "my_module", - "twins/my_device/modules/my_module/methods", - id="('my_device', 'my_module') ==> 'twins/my_device/modules/my_module/methods'", + "/twins/my_device/modules/my_module/methods", + id="('my_device', 'my_module') ==> '/twins/my_device/modules/my_module/methods'", ), pytest.param( "my/device", "my?module", - "twins/my%2Fdevice/modules/my%3Fmodule/methods", - id="('my/device', 'my?module') ==> 'twins/my%2Fdevice/modules/my%3Fmodule/methods'", + "/twins/my%2Fdevice/modules/my%3Fmodule/methods", + id="('my/device', 'my?module') ==> '/twins/my%2Fdevice/modules/my%3Fmodule/methods'", ), 
pytest.param( "my+device", "my+module", - "twins/my%2Bdevice/modules/my%2Bmodule/methods", - id="('my+device', 'my+module') ==> 'twins/my%2Bdevice/modules/my%2Bmodule/methods'", + "/twins/my%2Bdevice/modules/my%2Bmodule/methods", + id="('my+device', 'my+module') ==> '/twins/my%2Bdevice/modules/my%2Bmodule/methods'", ), ], ) def test_path(self, device_id, module_id, expected_path): - path = http_path_iothub.get_method_invoke_path(device_id=device_id, module_id=module_id) + path = http_path_iothub.get_direct_method_invoke_path( + device_id=device_id, module_id=module_id + ) assert path == expected_path @pytest.mark.describe(".get_storage_info_for_blob_path()") class TestGetStorageInfoPath(object): - @pytest.mark.it("Returns the storage info HTTP path") + @pytest.mark.it("Returns the relative storage info HTTP path") @pytest.mark.parametrize( "device_id, expected_path", [ pytest.param( "my_device", - "devices/my_device/files", - id="'my_device' ==> 'devices/my_device/files'", + "/devices/my_device/files", + id="'my_device' ==> '/devices/my_device/files'", ), pytest.param( "my/device", - "devices/my%2Fdevice/files", - id="'my/device' ==> 'devices/my%2Fdevice/files'", + "/devices/my%2Fdevice/files", + id="'my/device' ==> '/devices/my%2Fdevice/files'", ), pytest.param( "my+device", - "devices/my%2Bdevice/files", - id="'my+device' ==> 'devices/my%2Bdevice/files'", + "/devices/my%2Bdevice/files", + id="'my+device' ==> '/devices/my%2Bdevice/files'", ), ], ) @@ -93,24 +95,24 @@ def test_path(self, device_id, expected_path): @pytest.mark.describe(".get_notify_blob_upload_status_path()") class TestGetNotifyBlobUploadStatusPath(object): - @pytest.mark.it("Returns the notify blob upload status HTTP path") + @pytest.mark.it("Returns the relative notify blob upload status HTTP path") @pytest.mark.parametrize( "device_id, expected_path", [ pytest.param( "my_device", - "devices/my_device/files/notifications", - id="'my_device' ==> 'devices/my_device/files/notifications'", + 
"/devices/my_device/files/notifications", + id="'my_device' ==> '/devices/my_device/files/notifications'", ), pytest.param( "my/device", - "devices/my%2Fdevice/files/notifications", - id="'my/device' ==> 'devices/my%2Fdevice/files/notifications'", + "/devices/my%2Fdevice/files/notifications", + id="'my/device' ==> '/devices/my%2Fdevice/files/notifications'", ), pytest.param( "my+device", - "devices/my%2Bdevice/files/notifications", - id="'my+device' ==> 'devices/my%2Bdevice/files/notifications'", + "/devices/my%2Bdevice/files/notifications", + id="'my+device' ==> '/devices/my%2Bdevice/files/notifications'", ), ], ) diff --git a/tests/unit/test_iothub_http_client.py b/tests/unit/test_iothub_http_client.py new file mode 100644 index 000000000..824c60558 --- /dev/null +++ b/tests/unit/test_iothub_http_client.py @@ -0,0 +1,1032 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +import aiohttp +import asyncio +import pytest +import ssl +import time +import urllib.parse +from pytest_lazyfixture import lazy_fixture +from dev_utils import custom_mock +from azure.iot.device import config, constant, user_agent +from azure.iot.device import http_path_iothub as http_path +from azure.iot.device import sastoken as st +from azure.iot.device import exceptions as exc +from azure.iot.device.iothub_http_client import IoTHubHTTPClient + +FAKE_DEVICE_ID = "fake_device_id" +FAKE_MODULE_ID = "fake_module_id" +FAKE_HOSTNAME = "fake.hostname" +FAKE_SIGNATURE = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=" +FAKE_EXPIRY = str(int(time.time()) + 3600) +FAKE_URI = "fake/resource/location" +FAKE_CORRELATION_ID = ( + "MjAyMzAyMjIwNTQ5XzNjNTM5YTQyLWM1STItNDM3OS1iMzc5LWFiMTlhYTNhZWJjZV9zb21lIGJsb2JfdmVyMi4w" +) + + +# NOTE: We use "async with" statements when using aiohttp to do HTTP requests, i.e. context managers. +# It is not as easy to mock out context managers as regular functions/coroutines, but still quite +# doable, although doing so relies on some implementation knowledge of how context managers work. +# That said, you should be able to follow along fairly easily even without that, so just follow the +# templates here if modifying this file. +# +# All you really need to know is that the HTTP request itself (e.g. POST) is a regular function +# that ends up returning an async context manager, which is then used to do the request in an +# asynchronous fashion. 
This is why most tests related to the request itself will be checking +# the mock 'calls' rather than 'awaits', with only a few really verifying that the async context +# manager is being used by verifying the 'await' of the `__aenter__` coroutine + + +# ~~~~~ Fixtures ~~~~~ +@pytest.fixture +def sastoken(): + sastoken_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( + resource=FAKE_URI, signature=FAKE_SIGNATURE, expiry=FAKE_EXPIRY + ) + return st.SasToken(sastoken_str) + + +@pytest.fixture +def mock_sastoken_provider(mocker, sastoken): + provider = mocker.MagicMock(spec=st.SasTokenProvider) + provider.get_current_sastoken.return_value = sastoken + return provider + + +@pytest.fixture(autouse=True) +def mock_session(mocker): + mock_session = mocker.MagicMock(spec=aiohttp.ClientSession) + # Mock out POST and it's response + mock_response = mock_session.post.return_value.__aenter__.return_value + mock_response.status = 200 + mock_response.reason = "some reason" + return mock_session + + +@pytest.fixture +def client_config(): + """Defaults to Device Configuration. Required values only. + Customize in test if you need specific options (incl. Module)""" + + client_config = config.IoTHubClientConfig( + device_id=FAKE_DEVICE_ID, hostname=FAKE_HOSTNAME, ssl_context=ssl.SSLContext() + ) + return client_config + + +@pytest.fixture +async def client(mocker, client_config, mock_session): + client = IoTHubHTTPClient(client_config) + client._session = mock_session + + yield client + # Shutdown contains a sleep of 250ms, so mock it out to speed up test performance + mocker.patch.object(asyncio, "sleep") + await client.shutdown() + + +# ~~~~~ Saved Parametrizations ~~~~~ +failed_status_codes = [ + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status Code: 400"), + pytest.param(500, id="Status Code: 500"), +] + +http_post_exceptions = [ + # TODO: are there expected exceptions here? 
Needs to be manually tested and investigated + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception") +] + +http_response_json_exceptions = [ + # TODO: are there expected exceptions here? Needs to be manually tested and investigated + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception") +] + + +# ~~~~~ Tests ~~~~~ +@pytest.mark.describe("IoTHubHTTPClient -- Instantiation") +class TestIoTHubHTTPClientInstantiation: + # NOTE: As the instantiation is the unit under test here, we shouldn't use the client fixture. + # This means that you must do graceful exit by shutting down the client at the end of all tests + # and you may need to do a manual mock of the underlying HTTP client where appropriate. + configurations = [ + pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"), + ] + + @pytest.fixture(autouse=True) + def mock_asyncio_sleep(self, mocker): + """Mock asyncio sleep for performance so that shutdowns don't have a delay""" + mocker.patch.object(asyncio, "sleep") + + @pytest.mark.it( + "Stores the `device_id` and `module_id` values from the IoTHubClientConfig as attributes" + ) + @pytest.mark.parametrize("device_id, module_id", configurations) + async def test_simple_ids(self, client_config, device_id, module_id): + client_config.device_id = device_id + client_config.module_id = module_id + + client = IoTHubHTTPClient(client_config) + assert client._device_id == device_id + assert client._module_id == module_id + + await client.shutdown() + + @pytest.mark.it( + "Derives the `edge_module_id` from the `device_id` and `module_id` if the IoTHubClientConfig contains a `module_id`" + ) + async def test_edge_module_id(self, client_config): + client_config.device_id = FAKE_DEVICE_ID + client_config.module_id = FAKE_MODULE_ID + expected_edge_module_id = "{device_id}/{module_id}".format( + device_id=FAKE_DEVICE_ID, module_id=FAKE_MODULE_ID + ) + + 
client = IoTHubHTTPClient(client_config) + assert client._edge_module_id == expected_edge_module_id + + await client.shutdown() + + # NOTE: It would be nice if we could only do this for Edge modules, but there's no way to + # indicate a Module is Edge vs non-Edge + @pytest.mark.it("Sets the `edge_module_id` to None if not using a Module") + async def test_no_edge_module_id(self, client_config): + client_config.device_id = FAKE_DEVICE_ID + client_config.module_id = None + + client = IoTHubHTTPClient(client_config) + assert client._edge_module_id is None + + await client.shutdown() + + @pytest.mark.it( + "Constructs the `user_agent_string` by concatenating the base IoTHub user agent with the `product_info` from the IoTHubClientConfig" + ) + @pytest.mark.parametrize("device_id, module_id", configurations) + @pytest.mark.parametrize( + "product_info", + [ + pytest.param("", id="No Product Info"), + pytest.param("my-product-info", id="Custom Product Info"), + pytest.param( + constant.DIGITAL_TWIN_PREFIX + ":com:example:ClimateSensor;1", + id="Digital Twin Product Info", + ), + ], + ) + async def test_user_agent(self, client_config, device_id, module_id, product_info): + client_config.device_id = device_id + client_config.module_id = module_id + client_config.product_info = product_info + expected_user_agent = user_agent.get_iothub_user_agent() + product_info + + client = IoTHubHTTPClient(client_config) + assert client._user_agent_string == expected_user_agent + + await client.shutdown() + + @pytest.mark.it("Does not URL encode the user agent string") + @pytest.mark.parametrize("device_id, module_id", configurations) + @pytest.mark.parametrize( + "product_info", + [ + pytest.param("my$product$info", id="Custom Product Info"), + pytest.param( + constant.DIGITAL_TWIN_PREFIX + ":com:example:$Climate$ensor;1", + id="Digital Twin Product Info", + ), + ], + ) + async def test_user_agent_no_url_encoding( + self, client_config, device_id, module_id, product_info + ): + # NOTE: 
The user agent DOES eventually get url encoded, just not here, and not yet + client_config.device_id = device_id + client_config.module_id = module_id + client_config.product_info = product_info + expected_user_agent = user_agent.get_iothub_user_agent() + product_info + url_encoded_expected_user_agent = urllib.parse.quote_plus(expected_user_agent) + assert url_encoded_expected_user_agent != expected_user_agent + + client = IoTHubHTTPClient(client_config) + assert client._user_agent_string == expected_user_agent + + await client.shutdown() + + @pytest.mark.it( + "Creates a aiohttp ClientSession configured for accessing a URL based on the IoTHubClientConfig's `hostname`, with a timeout of 10 seconds" + ) + @pytest.mark.parametrize("device_id, module_id", configurations) + async def test_client_session(self, mocker, client_config, device_id, module_id): + client_config.device_id = device_id + client_config.module_id = module_id + + spy_session_init = mocker.spy(aiohttp, "ClientSession") + expected_base_url = "https://" + client_config.hostname + expected_timeout = 10 + + client = IoTHubHTTPClient(client_config) + assert spy_session_init.call_count == 1 + assert spy_session_init.call_args == mocker.call( + base_url=expected_base_url, timeout=mocker.ANY + ) + timeout_obj = spy_session_init.call_args[1]["timeout"] + assert isinstance(timeout_obj, aiohttp.ClientTimeout) + assert timeout_obj.total == expected_timeout + assert client._session is spy_session_init.spy_return + + await client.shutdown() + + @pytest.mark.it("Stores the `ssl_context` from the IoTHubClientConfig as an attribute") + @pytest.mark.parametrize("device_id, module_id", configurations) + async def test_ssl_context(self, client_config, device_id, module_id): + client_config.device_id = device_id + client_config.module_id = module_id + assert client_config.ssl_context is not None + + client = IoTHubHTTPClient(client_config) + assert client._ssl_context is client_config.ssl_context + + await 
client.shutdown() + + @pytest.mark.it("Stores the `sastoken_provider` from the IoTHubClientConfig as an attribute") + @pytest.mark.parametrize("device_id, module_id", configurations) + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="SasTokenProvider present"), + pytest.param(None, id="No SasTokenProvider present"), + ], + ) + async def test_sastoken_provider(self, client_config, device_id, module_id, sastoken_provider): + client_config.device_id = device_id + client_config.module_id = module_id + client_config.sastoken_provider = sastoken_provider + + client = IoTHubHTTPClient(client_config) + assert client._sastoken_provider is client_config.sastoken_provider + + await client.shutdown() + + +@pytest.mark.describe("IoTHubHTTPClient - .shutdown()") +class TestIoTHubHTTPClientShutdown: + @pytest.fixture(autouse=True) + def mock_asyncio_sleep(self, mocker): + """Mock asyncio sleep for performance so that shutdowns don't have a delay""" + return mocker.patch.object(asyncio, "sleep") + + @pytest.mark.it("Closes the aiohttp ClientSession") + async def test_close_session(self, mocker, client): + assert client._session.close.await_count == 0 + + await client.shutdown() + assert client._session.close.await_count == 1 + assert client._session.close.await_args == mocker.call() + + @pytest.mark.it("Waits 250ms to allow for proper SSL cleanup") + async def test_wait(self, mocker, client, mock_asyncio_sleep): + assert mock_asyncio_sleep.await_count == 0 + + await client.shutdown() + + assert mock_asyncio_sleep.await_count == 1 + assert mock_asyncio_sleep.await_args == mocker.call(0.25) + + @pytest.mark.it("Does not return a value") + async def test_return_value(self, client): + retval = await client.shutdown() + assert retval is None + + # TODO: Need to show shielding, but the mocking is difficult. Revisit. 
+ @pytest.mark.it("Can be cancelled while waiting for the aiohttp ClientSession to close") + async def test_cancel_during_close(self, client): + original_close = client._session.close + client._session.close = custom_mock.HangingAsyncMock() + try: + t = asyncio.create_task(client.shutdown()) + + # Hanging, waiting for close to finish + await client._session.close.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + finally: + # Restore original close so cleanup works correctly + client._session.close = original_close + + @pytest.mark.it("Can be cancelled while waiting for SSL cleanup") + async def test_cancel_during_wait(self, client): + original_sleep = asyncio.sleep + asyncio.sleep = custom_mock.HangingAsyncMock() + try: + t = asyncio.create_task(client.shutdown()) + + # Hanging, waiting for sleep to finish + await asyncio.sleep.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + finally: + # Restore original sleep to... everything... 
works correctly + asyncio.sleep = original_sleep + + +@pytest.mark.describe("IoTHubHTTPClient - .invoke_direct_method()") +class TestIoTHubHTTPClientInvokeDirectMethod: + @pytest.fixture(autouse=True) + def modify_client_config(self, client_config): + """Modify the client config to always be an Edge Module""" + client_config.device_id = FAKE_DEVICE_ID + client_config.module_id = FAKE_MODULE_ID + + @pytest.fixture(autouse=True) + def modify_post_response(self, client): + fake_method_response = { + "status": 200, + "payload": {"fake": "payload"}, + } + mock_response = client._session.post.return_value.__aenter__.return_value + mock_response.json.return_value = fake_method_response + + @pytest.fixture + def method_params(self): + return { + "methodName": "fake method", + "payload": {"fake": "payload"}, + "connectTimeoutInSeconds": 47, + "responseTimeoutInSeconds": 42, + } + + targets = [ + pytest.param("target_device", None, id="Target: Device"), + pytest.param("target_device", "target_module", id="Target: Module"), + ] + + @pytest.mark.it( + "Does an asynchronous POST request operation to the relative 'direct method invoke' path using the aiohttp ClientSession and the stored SSL context" + ) + @pytest.mark.parametrize("target_device_id, target_module_id", targets) + async def test_http_post( + self, mocker, client, target_device_id, target_module_id, method_params + ): + post_ctx_manager = client._session.post.return_value + assert client._session.post.call_count == 0 + assert post_ctx_manager.__aenter__.await_count == 0 + expected_path = http_path.get_direct_method_invoke_path(target_device_id, target_module_id) + + await client.invoke_direct_method( + device_id=target_device_id, module_id=target_module_id, method_params=method_params + ) + + assert client._session.post.call_count == 1 + assert client._session.post.call_args == mocker.call( + url=expected_path, + json=mocker.ANY, + params=mocker.ANY, + headers=mocker.ANY, + ssl=client._ssl_context, + ) + assert 
post_ctx_manager.__aenter__.await_count == 1 + assert post_ctx_manager.__aenter__.await_args == mocker.call() + + @pytest.mark.it( + "Sends the provided method parameters with the POST request as the JSON payload" + ) + @pytest.mark.parametrize("target_device_id, target_module_id", targets) + async def test_post_json( + self, mocker, client, target_device_id, target_module_id, method_params + ): + assert client._session.post.call_count == 0 + + await client.invoke_direct_method( + device_id=target_device_id, module_id=target_module_id, method_params=method_params + ) + + assert client._session.post.call_count == 1 + assert client._session.post.call_args == mocker.call( + url=mocker.ANY, + json=method_params, + params=mocker.ANY, + headers=mocker.ANY, + ssl=mocker.ANY, + ) + + @pytest.mark.it("Sends the API version with the POST request as a query parameter") + @pytest.mark.parametrize("target_device_id, target_module_id", targets) + async def test_post_query_params( + self, mocker, client, target_device_id, target_module_id, method_params + ): + assert client._session.post.call_count == 0 + expected_params = {"api-version": constant.IOTHUB_API_VERSION} + + await client.invoke_direct_method( + device_id=target_device_id, module_id=target_module_id, method_params=method_params + ) + + assert client._session.post.call_count == 1 + assert client._session.post.call_args == mocker.call( + url=mocker.ANY, + json=mocker.ANY, + params=expected_params, + headers=mocker.ANY, + ssl=mocker.ANY, + ) + + @pytest.mark.it( + "Sets the 'User-Agent' HTTP header on the POST request to the URL-encoded `user_agent` value stored on the client" + ) + @pytest.mark.parametrize("target_device_id, target_module_id", targets) + async def test_post_user_agent_header( + self, client, target_device_id, target_module_id, method_params + ): + assert client._session.post.call_count == 0 + expected_user_agent = urllib.parse.quote_plus(client._user_agent_string) + assert expected_user_agent != 
client._user_agent_string + + await client.invoke_direct_method( + device_id=target_device_id, module_id=target_module_id, method_params=method_params + ) + + assert client._session.post.call_count == 1 + headers = client._session.post.call_args[1]["headers"] + assert headers["User-Agent"] == expected_user_agent + + @pytest.mark.it( + "Sets the 'x-ms-edge-moduleId' HTTP header on the POST request to the `edge_module_id` value stored on the client" + ) + @pytest.mark.parametrize("target_device_id, target_module_id", targets) + async def test_post_edge_module_id_header( + self, client, target_device_id, target_module_id, method_params + ): + assert client._session.post.call_count == 0 + + await client.invoke_direct_method( + device_id=target_device_id, module_id=target_module_id, method_params=method_params + ) + + assert client._session.post.call_count == 1 + headers = client._session.post.call_args[1]["headers"] + assert headers["x-ms-edge-moduleId"] == client._edge_module_id + + @pytest.mark.it( + "Sets the 'Authorization' HTTP header on the POST request to the current SAS Token string from the SasTokenProvider stored on the client, if it exists" + ) + @pytest.mark.parametrize("target_device_id, target_module_id", targets) + async def test_post_authorization_header_sas( + self, client, target_device_id, target_module_id, method_params, mock_sastoken_provider + ): + assert client._session.post.call_count == 0 + client._sastoken_provider = mock_sastoken_provider + assert mock_sastoken_provider.get_current_sastoken.call_count == 0 + expected_sastoken_string = str(mock_sastoken_provider.get_current_sastoken.return_value) + + await client.invoke_direct_method( + device_id=target_device_id, module_id=target_module_id, method_params=method_params + ) + + assert client._session.post.call_count == 1 + headers = client._session.post.call_args[1]["headers"] + assert headers["Authorization"] == expected_sastoken_string + + @pytest.mark.it( + "Does not include an 
'Authorization' HTTP header on the POST request if not using SAS Token authentication" + ) + @pytest.mark.parametrize("target_device_id, target_module_id", targets) + async def test_post_authorization_header_no_sas( + self, client, target_device_id, target_module_id, method_params + ): + assert client._session.post.call_count == 0 + assert client._sastoken_provider is None + + await client.invoke_direct_method( + device_id=target_device_id, module_id=target_module_id, method_params=method_params + ) + + assert client._session.post.call_count == 1 + headers = client._session.post.call_args[1]["headers"] + assert "Authorization" not in headers + + @pytest.mark.it( + "Fetches and returns the JSON payload of the HTTP response, if the HTTP request was successful" + ) + @pytest.mark.parametrize("target_device_id, target_module_id", targets) + async def test_returns_json_payload( + self, client, target_device_id, target_module_id, method_params + ): + mock_response = client._session.post.return_value.__aenter__.return_value + assert mock_response.status == 200 + + method_response = await client.invoke_direct_method( + device_id=target_device_id, module_id=target_module_id, method_params=method_params + ) + + assert method_response is mock_response.json.return_value + + @pytest.mark.it( + "Raises an IoTEdgeError if a HTTP response is received with a failed status code" + ) + @pytest.mark.parametrize("target_device_id, target_module_id", targets) + @pytest.mark.parametrize("failed_status", failed_status_codes) + async def test_failed_response( + self, client, target_device_id, target_module_id, method_params, failed_status + ): + mock_response = client._session.post.return_value.__aenter__.return_value + mock_response.status = failed_status + + with pytest.raises(exc.IoTEdgeError): + await client.invoke_direct_method( + device_id=target_device_id, module_id=target_module_id, method_params=method_params + ) + + # NOTE: It'd be really great if we could reject non-Edge 
modules, but we can't. + @pytest.mark.it("Raises IoTHubClientError if not configured as a Module") + @pytest.mark.parametrize("target_device_id, target_module_id", targets) + async def test_not_edge(self, client, target_device_id, target_module_id, method_params): + client._module_id = None + client._edge_module_id = None + + with pytest.raises(exc.IoTHubClientError): + await client.invoke_direct_method( + device_id=target_device_id, module_id=target_module_id, method_params=method_params + ) + + @pytest.mark.it("Allows any exceptions raised by the POST request to propagate") + @pytest.mark.parametrize("target_device_id, target_module_id", targets) + @pytest.mark.parametrize("exception", http_post_exceptions) + async def test_http_post_raises( + self, client, target_device_id, target_module_id, method_params, exception + ): + client._session.post.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await client.invoke_direct_method( + device_id=target_device_id, module_id=target_module_id, method_params=method_params + ) + assert e_info.value is exception + + @pytest.mark.it("Allows any exceptions raised while getting the JSON response to propagate") + @pytest.mark.parametrize("target_device_id, target_module_id", targets) + @pytest.mark.parametrize("exception", http_response_json_exceptions) + async def test_http_response_json_raises( + self, client, target_device_id, target_module_id, method_params, exception + ): + mock_response = client._session.post.return_value.__aenter__.return_value + mock_response.json.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await client.invoke_direct_method( + device_id=target_device_id, module_id=target_module_id, method_params=method_params + ) + assert e_info.value is exception + + @pytest.mark.it("Can be cancelled while waiting for the HTTP response") + @pytest.mark.parametrize("target_device_id, target_module_id", targets) + async def test_cancel_during_request( + self, 
client, target_device_id, target_module_id, method_params + ): + post_ctx_manager = client._session.post.return_value + post_ctx_manager.__aenter__ = custom_mock.HangingAsyncMock() + + t = asyncio.create_task( + client.invoke_direct_method( + device_id=target_device_id, module_id=target_module_id, method_params=method_params + ) + ) + + # Hanging, waiting for response + await post_ctx_manager.__aenter__.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + @pytest.mark.it("Can be cancelled while fetching the payload of the HTTP response") + @pytest.mark.parametrize("target_device_id, target_module_id", targets) + async def test_cancel_during_payload_fetch( + self, client, target_device_id, target_module_id, method_params + ): + response = client._session.post.return_value.__aenter__.return_value + response.json = custom_mock.HangingAsyncMock() + + t = asyncio.create_task( + client.invoke_direct_method( + device_id=target_device_id, module_id=target_module_id, method_params=method_params + ) + ) + + # Hanging, waiting to fetch JSON + await response.json.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe("IoTHubHTTPClient - .get_storage_info_for_blob") +class TestIoTHubHTTPClientGetStorageInfoForBlob: + @pytest.fixture(autouse=True) + def modify_client_config(self, client_config): + """Modify the client config to always be a Device""" + client_config.device_id = FAKE_DEVICE_ID + client_config.module_id = None + + @pytest.fixture(autouse=True) + def modify_post_response(self, client): + fake_storage_info = { + "correlationId": FAKE_CORRELATION_ID, + "hostName": "fakeblobstorage.blob.core.windows.net", + "containerName": "fakeblobcontainer", + "blobName": "fake_device_id/fake_blob", + "sasToken": "?sv=2018-03-28&sr=b&sig=9x00K4bgLhiif0mVPTXRL8axz4yPG32LvnpVhwW4IfQ%3D&se=2023-02-22T05%3A39%3A49Z&sp=rw", + } 
+ mock_response = client._session.post.return_value.__aenter__.return_value + mock_response.json.return_value = fake_storage_info + + @pytest.mark.it( + "Does an asynchronous POST request operation to the relative 'get storage info' path using the aiohttp ClientSession and the stored SSL context" + ) + async def test_http_post(self, mocker, client): + post_ctx_manager = client._session.post.return_value + assert client._session.post.call_count == 0 + assert post_ctx_manager.__aenter__.await_count == 0 + expected_path = http_path.get_storage_info_for_blob_path(client._device_id) + + await client.get_storage_info_for_blob(blob_name="fake_blob") + + assert client._session.post.call_count == 1 + assert client._session.post.call_args == mocker.call( + url=expected_path, + json=mocker.ANY, + params=mocker.ANY, + headers=mocker.ANY, + ssl=client._ssl_context, + ) + assert post_ctx_manager.__aenter__.await_count == 1 + assert post_ctx_manager.__aenter__.await_args == mocker.call() + + @pytest.mark.it("Sends the provided `blob_name` with the POST request inside a JSON payload") + async def test_post_json(self, mocker, client): + assert client._session.post.call_count == 0 + expected_json = {"blobName": "fake_blob"} + + await client.get_storage_info_for_blob(blob_name="fake_blob") + + assert client._session.post.call_count == 1 + assert client._session.post.call_args == mocker.call( + url=mocker.ANY, + json=expected_json, + params=mocker.ANY, + headers=mocker.ANY, + ssl=mocker.ANY, + ) + + @pytest.mark.it("Sends the API version with the POST request as a query parameter") + async def test_post_query_params(self, mocker, client): + assert client._session.post.call_count == 0 + expected_params = {"api-version": constant.IOTHUB_API_VERSION} + + await client.get_storage_info_for_blob(blob_name="fake_blob") + + assert client._session.post.call_count == 1 + assert client._session.post.call_args == mocker.call( + url=mocker.ANY, + json=mocker.ANY, + params=expected_params, + 
headers=mocker.ANY, + ssl=mocker.ANY, + ) + + @pytest.mark.it( + "Sets the 'User-Agent' HTTP header on the POST request to the URL-encoded `user_agent` value stored on the client" + ) + async def test_post_user_agent_header(self, client): + assert client._session.post.call_count == 0 + expected_user_agent = urllib.parse.quote_plus(client._user_agent_string) + assert expected_user_agent != client._user_agent_string + + await client.get_storage_info_for_blob(blob_name="fake_blob") + + assert client._session.post.call_count == 1 + headers = client._session.post.call_args[1]["headers"] + assert headers["User-Agent"] == expected_user_agent + + @pytest.mark.it( + "Sets the 'Authorization' HTTP header on the POST request to the current SAS Token string from the SasTokenProvider stored on the client, if it exists" + ) + async def test_post_authorization_header_sas(self, client, mock_sastoken_provider): + assert client._session.post.call_count == 0 + client._sastoken_provider = mock_sastoken_provider + assert mock_sastoken_provider.get_current_sastoken.call_count == 0 + expected_sastoken_string = str(mock_sastoken_provider.get_current_sastoken.return_value) + + await client.get_storage_info_for_blob(blob_name="fake_blob") + + assert client._session.post.call_count == 1 + headers = client._session.post.call_args[1]["headers"] + assert headers["Authorization"] == expected_sastoken_string + + @pytest.mark.it( + "Does not include an 'Authorization' HTTP header on the POST request if not using SAS Token authentication" + ) + async def test_post_authorization_header_no_sas(self, client): + assert client._session.post.call_count == 0 + assert client._sastoken_provider is None + + await client.get_storage_info_for_blob(blob_name="fake_blob") + + assert client._session.post.call_count == 1 + headers = client._session.post.call_args[1]["headers"] + assert "Authorization" not in headers + + @pytest.mark.it( + "Fetches and returns the JSON payload of the HTTP response, if the HTTP 
request was successful" + ) + async def test_returns_json_payload(self, client): + mock_response = client._session.post.return_value.__aenter__.return_value + assert mock_response.status == 200 + + storage_info = await client.get_storage_info_for_blob(blob_name="fake_blob") + + assert storage_info is mock_response.json.return_value + + @pytest.mark.it( + "Raises an IoTHubError if a HTTP response is received with a failed status code" + ) + @pytest.mark.parametrize("failed_status", failed_status_codes) + async def test_failed_response(self, client, failed_status): + mock_response = client._session.post.return_value.__aenter__.return_value + mock_response.status = failed_status + + with pytest.raises(exc.IoTHubError): + await client.get_storage_info_for_blob(blob_name="fake_blob") + + @pytest.mark.it("Raises IoTHubClientError if not configured as a Device") + @pytest.mark.parametrize( + "edge_module_id", + [ + pytest.param(None, id="Module Configuration"), + pytest.param(FAKE_DEVICE_ID + "/" + FAKE_MODULE_ID, id="Edge Module Configuration"), + ], + ) + async def test_not_device(self, client, edge_module_id): + assert client._device_id is not None + client._module_id = FAKE_MODULE_ID + client._edge_module_id = edge_module_id + + with pytest.raises(exc.IoTHubClientError): + await client.get_storage_info_for_blob(blob_name="some blob") + + @pytest.mark.it("Allows any exceptions raised by the POST request to propagate") + @pytest.mark.parametrize("exception", http_post_exceptions) + async def test_http_post_raises(self, client, exception): + client._session.post.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await client.get_storage_info_for_blob(blob_name="some blob") + assert e_info.value is exception + + @pytest.mark.it("Allows any exceptions raised while getting the JSON response to propagate") + @pytest.mark.parametrize("exception", http_response_json_exceptions) + async def test_http_response_json_raises(self, client, exception): + 
mock_response = client._session.post.return_value.__aenter__.return_value + mock_response.json.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await client.get_storage_info_for_blob(blob_name="some blob") + assert e_info.value is exception + + @pytest.mark.it("Can be cancelled while waiting for the HTTP response") + async def test_cancel_during_request(self, client): + post_ctx_manager = client._session.post.return_value + post_ctx_manager.__aenter__ = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(client.get_storage_info_for_blob(blob_name="some blob")) + + # Hanging, waiting for response + await post_ctx_manager.__aenter__.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + @pytest.mark.it("Can be cancelled while fetching the payload of the HTTP response") + async def test_cancel_during_payload_fetch(self, client): + response = client._session.post.return_value.__aenter__.return_value + response.json = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(client.get_storage_info_for_blob(blob_name="some blob")) + + # Hanging, waiting to fetch JSON + await response.json.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe("IoTHubHTTPClient - .notify_blob_upload_status") +class TestIoTHubHTTPClientNotifyBlobUploadStatus: + @pytest.fixture(autouse=True) + def modify_client_config(self, client_config): + """Modify the client config to always be a Device""" + client_config.device_id = FAKE_DEVICE_ID + client_config.module_id = None + + @pytest.fixture(params=["Notify Upload Success", "Notify Upload Failure"]) + def kwargs(self, request): + """Because there are correlated semantics across a set of given arguments, using a fixture + just makes things easier + """ + if request.param == "Notify Upload Success": + kwargs = { + "correlation_id": FAKE_CORRELATION_ID, 
+ "is_success": True, + "status_code": 200, + "status_description": "Success!", + } + else: + kwargs = { + "correlation_id": FAKE_CORRELATION_ID, + "is_success": False, + "status_code": 500, + "status_description": "Failure!", + } + return kwargs + + @pytest.mark.it( + "Does an asynchronous POST request operation to the relative 'notify blob upload status' path using the aiohttp ClientSession and the stored SSL context" + ) + async def test_http_post(self, mocker, client, kwargs): + post_ctx_manager = client._session.post.return_value + assert client._session.post.call_count == 0 + assert post_ctx_manager.__aenter__.await_count == 0 + expected_path = http_path.get_notify_blob_upload_status_path(client._device_id) + + await client.notify_blob_upload_status(**kwargs) + + assert client._session.post.call_count == 1 + assert client._session.post.call_args == mocker.call( + url=expected_path, + json=mocker.ANY, + params=mocker.ANY, + headers=mocker.ANY, + ssl=client._ssl_context, + ) + assert post_ctx_manager.__aenter__.await_count == 1 + assert post_ctx_manager.__aenter__.await_args == mocker.call() + + @pytest.mark.it("Sends all the provided parameters with the POST request inside a JSON payload") + async def test_post_json(self, mocker, client, kwargs): + assert client._session.post.call_count == 0 + expected_json = { + "correlationId": kwargs["correlation_id"], + "isSuccess": kwargs["is_success"], + "statusCode": kwargs["status_code"], + "statusDescription": kwargs["status_description"], + } + + await client.notify_blob_upload_status(**kwargs) + + assert client._session.post.call_count == 1 + assert client._session.post.call_args == mocker.call( + url=mocker.ANY, + json=expected_json, + params=mocker.ANY, + headers=mocker.ANY, + ssl=mocker.ANY, + ) + + @pytest.mark.it("Sends the API version with the POST request as a query parameter") + async def test_post_query_params(self, mocker, client, kwargs): + assert client._session.post.call_count == 0 + expected_params = 
{"api-version": constant.IOTHUB_API_VERSION} + + await client.notify_blob_upload_status(**kwargs) + + assert client._session.post.call_count == 1 + assert client._session.post.call_args == mocker.call( + url=mocker.ANY, + json=mocker.ANY, + params=expected_params, + headers=mocker.ANY, + ssl=mocker.ANY, + ) + + @pytest.mark.it( + "Sets the 'User-Agent' HTTP header on the POST request to the URL-encoded `user_agent` value stored on the client" + ) + async def test_post_user_agent_header(self, client, kwargs): + assert client._session.post.call_count == 0 + expected_user_agent = urllib.parse.quote_plus(client._user_agent_string) + assert expected_user_agent != client._user_agent_string + + await client.notify_blob_upload_status(**kwargs) + + assert client._session.post.call_count == 1 + headers = client._session.post.call_args[1]["headers"] + assert headers["User-Agent"] == expected_user_agent + + @pytest.mark.it( + "Sets the 'Authorization' HTTP header on the POST request to the current SAS Token string from the SasTokenProvider stored on the client, if it exists" + ) + async def test_post_authorization_header_sas(self, client, kwargs, mock_sastoken_provider): + assert client._session.post.call_count == 0 + client._sastoken_provider = mock_sastoken_provider + assert mock_sastoken_provider.get_current_sastoken.call_count == 0 + expected_sastoken_string = str(mock_sastoken_provider.get_current_sastoken.return_value) + + await client.notify_blob_upload_status(**kwargs) + + assert client._session.post.call_count == 1 + headers = client._session.post.call_args[1]["headers"] + assert headers["Authorization"] == expected_sastoken_string + + @pytest.mark.it( + "Does not include an 'Authorization' HTTP header on the POST request if not using SAS Token authentication" + ) + async def test_post_authorization_header_no_sas(self, client, kwargs): + assert client._session.post.call_count == 0 + assert client._sastoken_provider is None + + await 
client.notify_blob_upload_status(**kwargs) + + assert client._session.post.call_count == 1 + headers = client._session.post.call_args[1]["headers"] + assert "Authorization" not in headers + + @pytest.mark.it("Does not return a value") + async def test_return_value(self, client, kwargs): + retval = await client.notify_blob_upload_status(**kwargs) + assert retval is None + + @pytest.mark.it( + "Raises an IoTHubError if a HTTP response is received with a failed status code" + ) + @pytest.mark.parametrize("failed_status", failed_status_codes) + async def test_failed_response(self, client, kwargs, failed_status): + mock_response = client._session.post.return_value.__aenter__.return_value + mock_response.status = failed_status + + with pytest.raises(exc.IoTHubError): + await client.notify_blob_upload_status(**kwargs) + + @pytest.mark.it("Raises IoTHubClientError if not configured as a Device") + @pytest.mark.parametrize( + "edge_module_id", + [ + pytest.param(None, id="Module Configuration"), + pytest.param(FAKE_DEVICE_ID + "/" + FAKE_MODULE_ID, id="Edge Module Configuration"), + ], + ) + async def test_not_device(self, client, kwargs, edge_module_id): + assert client._device_id is not None + client._module_id = FAKE_MODULE_ID + client._edge_module_id = edge_module_id + + with pytest.raises(exc.IoTHubClientError): + await client.notify_blob_upload_status(**kwargs) + + @pytest.mark.it("Allows any exceptions raised by the POST request to propagate") + @pytest.mark.parametrize("exception", http_post_exceptions) + async def test_http_post_raises(self, client, kwargs, exception): + client._session.post.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await client.notify_blob_upload_status(**kwargs) + assert e_info.value is exception + + @pytest.mark.it("Can be cancelled while waiting for the HTTP response") + async def test_cancel_during_request(self, client, kwargs): + post_ctx_manager = client._session.post.return_value + 
post_ctx_manager.__aenter__ = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(client.notify_blob_upload_status(**kwargs)) + + # Hanging, waiting for response + await post_ctx_manager.__aenter__.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t diff --git a/tests/unit/test_iothub_mqtt_client.py b/tests/unit/test_iothub_mqtt_client.py new file mode 100644 index 000000000..12d7fadb2 --- /dev/null +++ b/tests/unit/test_iothub_mqtt_client.py @@ -0,0 +1,3753 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import abc +import asyncio +import json +import pytest +import ssl +import sys +import time +import typing +import urllib +from pytest_lazyfixture import lazy_fixture +from dev_utils import custom_mock +from azure.iot.device.iothub_mqtt_client import ( + IoTHubMQTTClient, + DEFAULT_RECONNECT_INTERVAL, +) +from azure.iot.device import config, constant, models, user_agent +from azure.iot.device import exceptions as exc +from azure.iot.device import mqtt_client as mqtt +from azure.iot.device import request_response as rr +from azure.iot.device import mqtt_topic_iothub as mqtt_topic +from azure.iot.device import sastoken as st + + +FAKE_DEVICE_ID = "fake_device_id" +FAKE_MODULE_ID = "fake_module_id" +FAKE_DEVICE_CLIENT_ID = "fake_device_id" +FAKE_MODULE_CLIENT_ID = "fake_device_id/fake_module_id" +FAKE_HOSTNAME = "fake.hostname" +FAKE_SIGNATURE = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=" +FAKE_EXPIRY = str(int(time.time()) + 3600) +FAKE_URI = "fake/resource/location" +FAKE_INPUT_NAME = "fake_input" + + +# Parametrizations +# TODO: expand this when we know more about what exceptions get raised from MQTTClient 
+mqtt_connect_exceptions = [ + pytest.param(exc.MQTTConnectionFailedError(), id="MQTTConnectionFailedError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), +] +mqtt_disconnect_exceptions = [ + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception") +] +mqtt_publish_exceptions = [ + pytest.param(exc.MQTTError(rc=5), id="MQTTError"), + pytest.param(ValueError(), id="ValueError"), + pytest.param(TypeError(), id="TypeError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), +] +mqtt_subscribe_exceptions = [ + # NOTE: CancelledError is here because network failure can cancel a subscribe + # without explicit invocation of cancel on the subscribe + pytest.param(exc.MQTTError(rc=5), id="MQTTError"), + pytest.param(asyncio.CancelledError(), id="CancelledError (Not initiated)"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), +] +mqtt_unsubscribe_exceptions = [ + # NOTE: CancelledError is here because network failure can cancel an unsubscribe + # without explicit invocation of cancel on the unsubscribe + pytest.param(exc.MQTTError(rc=5), id="MQTTError"), + pytest.param(asyncio.CancelledError(), id="CancelledError (Not initiated)"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), +] + + +# Fixtures + + +@pytest.fixture +def sastoken(): + sastoken_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( + resource=FAKE_URI, signature=FAKE_SIGNATURE, expiry=FAKE_EXPIRY + ) + return st.SasToken(sastoken_str) + + +@pytest.fixture +def mock_sastoken_provider(mocker, sastoken): + provider = mocker.MagicMock(spec=st.SasTokenProvider) + provider.get_current_sastoken.return_value = sastoken + # Use a HangingAsyncMock so that it isn't constantly returning + provider.wait_for_new_sastoken = custom_mock.HangingAsyncMock() + provider.wait_for_new_sastoken.return_value = sastoken + # NOTE: Technically, this mock just 
always returns the same SasToken, + # even after an "update", but for the purposes of testing at this level, + # it doesn't matter + return provider + + +@pytest.fixture +def client_config(): + """Defaults to Device Configuration. Required values only. + Customize in test if you need specific options (incl. Module)""" + + client_config = config.IoTHubClientConfig( + device_id=FAKE_DEVICE_ID, hostname=FAKE_HOSTNAME, ssl_context=ssl.SSLContext() + ) + return client_config + + +@pytest.fixture +async def client(mocker, client_config): + client = IoTHubMQTTClient(client_config) + # Mock just the network operations from the MQTTClient, not the whole thing. + # This makes using the generators easier + client._mqtt_client.connect = mocker.AsyncMock() + client._mqtt_client.disconnect = mocker.AsyncMock() + client._mqtt_client.subscribe = mocker.AsyncMock() + client._mqtt_client.unsubscribe = mocker.AsyncMock() + client._mqtt_client.publish = mocker.AsyncMock() + # Also mock other methods relevant to tests + client._mqtt_client.set_credentials = mocker.MagicMock() + client._mqtt_client.is_connected = mocker.MagicMock() + + # NOTE: No need to invoke .start() here, at least, not yet. + return client + + +@pytest.mark.describe("IoTHubMQTTClient -- Instantiation") +class TestIoTHubMQTTClientInstantiation: + # NOTE: As the instantiation is the unit under test here, we shouldn't use the client fixture. + # This means you may need to do a manual mock of the underlying MQTTClient where appropriate. 
+ + @pytest.mark.it( + "Stores the `device_id` and `module_id` values from the IoTHubClientConfig as attributes" + ) + @pytest.mark.parametrize( + "device_id, module_id", + [ + pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"), + ], + ) + async def test_simple_ids(self, client_config, device_id, module_id): + client_config.device_id = device_id + client_config.module_id = module_id + + client = IoTHubMQTTClient(client_config) + assert client._device_id == client_config.device_id + assert client._module_id == client_config.module_id + + @pytest.mark.it( + "Derives the `client_id` from the `device_id` and `module_id` and stores it as an attribute" + ) + @pytest.mark.parametrize( + "device_id, module_id, expected_client_id", + [ + pytest.param(FAKE_DEVICE_ID, None, FAKE_DEVICE_CLIENT_ID, id="Device Configuration"), + pytest.param( + FAKE_DEVICE_ID, FAKE_MODULE_ID, FAKE_MODULE_CLIENT_ID, id="Module Configuration" + ), + ], + ) + async def test_client_id(self, client_config, device_id, module_id, expected_client_id): + client_config.device_id = device_id + client_config.module_id = module_id + + client = IoTHubMQTTClient(client_config) + assert client._client_id == expected_client_id + + @pytest.mark.it("Derives the `username` and stores the result as an attribute") + @pytest.mark.parametrize( + "device_id, module_id, client_id", + [ + pytest.param(FAKE_DEVICE_ID, None, FAKE_DEVICE_CLIENT_ID, id="Device Configuration"), + pytest.param( + FAKE_DEVICE_ID, FAKE_MODULE_ID, FAKE_MODULE_CLIENT_ID, id="Module Configuration" + ), + ], + ) + @pytest.mark.parametrize( + "product_info", + [ + pytest.param("", id="No Product Info"), + pytest.param("my-product-info", id="Custom Product Info"), + pytest.param("my$product$info", id="Custom Product Info (URL encoding required)"), + pytest.param( + constant.DIGITAL_TWIN_PREFIX + ":com:example:ClimateSensor;1", + id="Digital Twin Product Info", + ), + 
pytest.param( + constant.DIGITAL_TWIN_PREFIX + ":com:example:$Climate$ensor;1", + id="Digital Twin Product Info (URL encoding required)", + ), + ], + ) + async def test_username( + self, + client_config, + device_id, + module_id, + client_id, + product_info, + ): + client_config.device_id = device_id + client_config.module_id = module_id + client_config.product_info = product_info + + ua = user_agent.get_iothub_user_agent() + url_encoded_user_agent = urllib.parse.quote(ua, safe="") + # NOTE: This assertion shows the URL encoding was meaningful + assert user_agent != url_encoded_user_agent + + url_encoded_product_info = urllib.parse.quote(product_info, safe="") + # NOTE: We can't really make the same assertion here, because this isn't always meaningful + + # Determine expected username based on config + if product_info.startswith(constant.DIGITAL_TWIN_PREFIX): + expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}&{digital_twin_prefix}={custom_product_info}".format( + hostname=client_config.hostname, + client_id=client_id, + api_version=constant.IOTHUB_API_VERSION, + user_agent=url_encoded_user_agent, + digital_twin_prefix=constant.DIGITAL_TWIN_QUERY_HEADER, + custom_product_info=url_encoded_product_info, + ) + else: + expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}{custom_product_info}".format( + hostname=client_config.hostname, + client_id=client_id, + api_version=constant.IOTHUB_API_VERSION, + user_agent=url_encoded_user_agent, + custom_product_info=url_encoded_product_info, + ) + + client = IoTHubMQTTClient(client_config) + # The expected username was derived + assert client._username == expected_username + + @pytest.mark.it("Stores the `sastoken_provider` from the IoTHubClientConfig as an attribute") + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="SasTokenProvider present"), + pytest.param(None, 
id="No SasTokenProvider present"), + ], + ) + @pytest.mark.parametrize( + "device_id, module_id", + [ + pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"), + ], + ) + async def test_sastoken_provider(self, client_config, sastoken_provider, device_id, module_id): + client_config.device_id = device_id + client_config.module_id = module_id + client_config.sastoken_provider = sastoken_provider + + client = IoTHubMQTTClient(client_config) + assert client._sastoken_provider is sastoken_provider + + @pytest.mark.it( + "Creates an MQTTClient instance based on the configuration of IoTHubClientConfig and stores it as an attribute" + ) + @pytest.mark.parametrize( + "device_id, module_id, expected_client_id", + [ + pytest.param(FAKE_DEVICE_ID, None, FAKE_DEVICE_CLIENT_ID, id="Device Configuration"), + pytest.param( + FAKE_DEVICE_ID, FAKE_MODULE_ID, FAKE_MODULE_CLIENT_ID, id="Module Configuration" + ), + ], + ) + @pytest.mark.parametrize( + "websockets, expected_transport, expected_port, expected_ws_path", + [ + pytest.param(True, "websockets", 443, "/$iothub/websocket", id="WebSockets"), + pytest.param(False, "tcp", 8883, None, id="TCP"), + ], + ) + async def test_mqtt_client( + self, + mocker, + client_config, + device_id, + module_id, + expected_client_id, + websockets, + expected_transport, + expected_port, + expected_ws_path, + ): + # Configure the client_config based on params + client_config.device_id = device_id + client_config.module_id = module_id + client_config.websockets = websockets + + # Patch the MQTTClient constructor + mock_constructor = mocker.patch.object(mqtt, "MQTTClient", spec=mqtt.MQTTClient) + assert mock_constructor.call_count == 0 + + # Create the client under test + client = IoTHubMQTTClient(client_config) + + # Assert that the MQTTClient was constructed as expected + assert mock_constructor.call_count == 1 + assert mock_constructor.call_args == mocker.call( + 
client_id=expected_client_id, + hostname=client_config.hostname, + port=expected_port, + transport=expected_transport, + keep_alive=client_config.keep_alive, + auto_reconnect=client_config.auto_reconnect, + reconnect_interval=DEFAULT_RECONNECT_INTERVAL, + ssl_context=client_config.ssl_context, + websockets_path=expected_ws_path, + proxy_options=client_config.proxy_options, + ) + assert client._mqtt_client is mock_constructor.return_value + + @pytest.mark.it( + "Adds incoming message filter on the MQTTClient for C2D messages, if using a Device Configuration" + ) + async def test_c2d_filter(self, mocker, client_config): + client_config.device_id = FAKE_DEVICE_ID + client_config.module_id = None + expected_topic = mqtt_topic.get_c2d_topic_for_subscribe(FAKE_DEVICE_ID) + + mocker.patch.object(mqtt, "MQTTClient", spec=mqtt.MQTTClient) + client = IoTHubMQTTClient(client_config) + + # NOTE: Multiple filters are added, but not all are covered in this test + assert ( + mocker.call(expected_topic) + in client._mqtt_client.add_incoming_message_filter.call_args_list + ) + + @pytest.mark.it( + "Does not add incoming message filter on the MQTTClient for C2D messages, if using a Module Configuration" + ) + async def test_c2d_message_filter_device(self, mocker, client_config): + client_config.device_id = FAKE_DEVICE_ID + client_config.module_id = FAKE_MODULE_ID + + mocker.patch.object(mqtt, "MQTTClient", spec=mqtt.MQTTClient) + client = IoTHubMQTTClient(client_config) + + # NOTE: It's kind of weird to try and show a method wasn't called with an argument, when + # what that argument would even be can't be created without a module ID in the first place. 
+ # What we do here is check every topic that a filter is added for to ensure none of them + # contain the word "input", which an input message topic would uniquely have + for call in client._mqtt_client.add_incoming_message_filter.call_args_list: + topic = call[0][0] + assert "devicebound" not in topic + + @pytest.mark.it( + "Adds incoming message filter on the MQTTClient for input messages, if using a Module Configuration" + ) + async def test_input_message_filter_module(self, mocker, client_config): + client_config.device_id = FAKE_DEVICE_ID + client_config.module_id = FAKE_MODULE_ID + expected_topic = mqtt_topic.get_input_topic_for_subscribe(FAKE_DEVICE_ID, FAKE_MODULE_ID) + + mocker.patch.object(mqtt, "MQTTClient", spec=mqtt.MQTTClient) + client = IoTHubMQTTClient(client_config) + + # NOTE: Multiple filters are added, but not all are covered in this test + assert ( + mocker.call(expected_topic) + in client._mqtt_client.add_incoming_message_filter.call_args_list + ) + + @pytest.mark.it( + "Does not add incoming message filter on the MQTTClient for input messages, if using a Device Configuration" + ) + async def test_input_message_filter_device(self, mocker, client_config): + client_config.device_id = FAKE_DEVICE_ID + client_config.module_id = None + + mocker.patch.object(mqtt, "MQTTClient", spec=mqtt.MQTTClient) + client = IoTHubMQTTClient(client_config) + + # NOTE: It's kind of weird to try and show a method wasn't called with an argument, when + # what that argument would even be can't be created without a module ID in the first place. 
+ # What we do here is check every topic that a filter is added for to ensure none of them + # contain the word "input", which an input message topic would uniquely have + for call in client._mqtt_client.add_incoming_message_filter.call_args_list: + topic = call[0][0] + assert "inputs" not in topic + + @pytest.mark.it("Adds incoming message filter on the MQTTClient for direct method requests") + @pytest.mark.parametrize( + "device_id, module_id", + [ + pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"), + ], + ) + async def test_direct_method_request_filter(self, mocker, client_config, device_id, module_id): + client_config.device_id = device_id + client_config.module_id = module_id + expected_topic = mqtt_topic.get_direct_method_request_topic_for_subscribe() + + mocker.patch.object(mqtt, "MQTTClient", spec=mqtt.MQTTClient) + client = IoTHubMQTTClient(client_config) + + # NOTE: Multiple filters are added, but not all are covered in this test + assert ( + mocker.call(expected_topic) + in client._mqtt_client.add_incoming_message_filter.call_args_list + ) + + @pytest.mark.it("Adds incoming message filter on the MQTTClient for twin patches") + @pytest.mark.parametrize( + "device_id, module_id", + [ + pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"), + ], + ) + async def test_twin_patch_filter(self, mocker, client_config, device_id, module_id): + client_config.device_id = device_id + client_config.module_id = module_id + expected_topic = mqtt_topic.get_twin_patch_topic_for_subscribe() + + mocker.patch.object(mqtt, "MQTTClient", spec=mqtt.MQTTClient) + client = IoTHubMQTTClient(client_config) + + # NOTE: Multiple filters are added, but not all are covered in this test + assert ( + mocker.call(expected_topic) + in client._mqtt_client.add_incoming_message_filter.call_args_list + ) + + 
@pytest.mark.it("Adds incoming message filter on the MQTTClient for twin responses") + @pytest.mark.parametrize( + "device_id, module_id", + [ + pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"), + ], + ) + async def test_twin_response_filter(self, mocker, client_config, device_id, module_id): + client_config.device_id = device_id + client_config.module_id = module_id + expected_topic = mqtt_topic.get_twin_response_topic_for_subscribe() + + mocker.patch.object(mqtt, "MQTTClient", spec=mqtt.MQTTClient) + client = IoTHubMQTTClient(client_config) + + # NOTE: Multiple filters are added, but not all are covered in this test + assert ( + mocker.call(expected_topic) + in client._mqtt_client.add_incoming_message_filter.call_args_list + ) + + # NOTE: For testing the functionality of this generator, see the corresponding test suite (TestIoTHubMQTTClientIncomingC2DMessages) + @pytest.mark.it( + "Provides an incoming C2D message generator as a read-only property, if using a Device Configuration" + ) + async def test_incoming_c2d_messages_device(self, client_config): + client_config.device_id = FAKE_DEVICE_ID + client_config.module_id = None + client = IoTHubMQTTClient(client_config) + assert client.incoming_c2d_messages + assert isinstance(client._incoming_c2d_messages, typing.AsyncGenerator) + + @pytest.mark.it( + "Does not create an incoming C2D message generator, if using a Module Configuration" + ) + async def test_c2d_message_generator_module(self, client_config): + client_config.device_id = FAKE_DEVICE_ID + client_config.module_id = FAKE_MODULE_ID + client = IoTHubMQTTClient(client_config) + assert client._incoming_c2d_messages is None + + # NOTE: For testing the functionality of this generator, see the corresponding test suite (TestIoTHubMQTTClientIncomingInputMessages) + @pytest.mark.it( + "Creates and stores an incoming input message generator as an attribute, if using a Module 
Configuration" + ) + async def test_input_message_generator_module(self, client_config): + client_config.device_id = FAKE_DEVICE_ID + client_config.module_id = FAKE_MODULE_ID + client = IoTHubMQTTClient(client_config) + assert isinstance(client._incoming_input_messages, typing.AsyncGenerator) + + @pytest.mark.it( + "Does not create an incoming input message generator, if using a Device Configuration" + ) + async def test_input_message_generator_device(self, client_config): + client_config.device_id = FAKE_DEVICE_ID + client_config.module_id = None + client = IoTHubMQTTClient(client_config) + assert client._incoming_input_messages is None + + # NOTE: For testing the functionality of this generator, see the corresponding test suite (TestIoTHubMQTTClientIncomingDirectDirectMethodRequests) + @pytest.mark.it( + "Creates and stores an incoming direct method request generator as an attribute" + ) + @pytest.mark.parametrize( + "device_id, module_id", + [ + pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"), + ], + ) + async def test_direct_method_request_generator(self, client_config, device_id, module_id): + client_config.device_id = device_id + client_config.module_id = module_id + client = IoTHubMQTTClient(client_config) + assert isinstance(client._incoming_direct_method_requests, typing.AsyncGenerator) + + # NOTE: For testing the functionality of this generator, see the corresponding test suite (TestIoTHubMQTTClientIncomingTwinPatches) + @pytest.mark.it("Creates and stores an incoming twin patch generator as an attribute") + @pytest.mark.parametrize( + "device_id, module_id", + [ + pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"), + ], + ) + async def test_twin_patch_generator(self, client_config, device_id, module_id): + client_config.device_id = device_id + client_config.module_id = module_id + 
client = IoTHubMQTTClient(client_config) + assert isinstance(client._incoming_twin_patches, typing.AsyncGenerator) + + @pytest.mark.it("Creates an empty RequestLedger") + @pytest.mark.parametrize( + "device_id, module_id", + [ + pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"), + ], + ) + async def test_request_ledger(self, client_config, device_id, module_id): + client_config.device_id = device_id + client_config.module_id = module_id + client = IoTHubMQTTClient(client_config) + assert isinstance(client._request_ledger, rr.RequestLedger) + assert len(client._request_ledger) == 0 + + @pytest.mark.it("Sets the twin_responses_enabled flag to False") + @pytest.mark.parametrize( + "device_id, module_id", + [ + pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"), + ], + ) + async def test_twin_responses_enabled(self, client_config, device_id, module_id): + client_config.device_id = device_id + client_config.module_id = module_id + client = IoTHubMQTTClient(client_config) + assert client._twin_responses_enabled is False + + @pytest.mark.it("Sets background task attributes to None") + @pytest.mark.parametrize( + "device_id, module_id", + [ + pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"), + ], + ) + async def test_bg_tasks(self, client_config, device_id, module_id): + client_config.device_id = device_id + client_config.module_id = module_id + client = IoTHubMQTTClient(client_config) + assert client._process_twin_responses_bg_task is None + + +@pytest.mark.describe("IoTHubMQTTClient - .start()") +class TestIoTHubMQTTClientStart: + @pytest.mark.it( + "Sets the credentials on the MQTTClient, using the stored `username` as the username and no password, when not using SAS authentication" + ) + async def 
test_mqtt_client_credentials_no_sas(self, mocker, client):
+        assert client._sastoken_provider is None
+        assert client._mqtt_client.set_credentials.call_count == 0
+
+        await client.start()
+
+        assert client._mqtt_client.set_credentials.call_count == 1
+        assert client._mqtt_client.set_credentials.call_args == mocker.call(client._username, None)
+
+        # Cleanup
+        await client.stop()
+
+    @pytest.mark.it(
+        "Sets the credentials on the MQTTClient, using the stored `username` as the username and the string-converted current SasToken from the SasTokenProvider as the password, when using SAS authentication"
+    )
+    async def test_mqtt_client_credentials_with_sas(self, mocker, client, mock_sastoken_provider):
+        client._sastoken_provider = mock_sastoken_provider
+        fake_sastoken = mock_sastoken_provider.get_current_sastoken.return_value
+        assert client._mqtt_client.set_credentials.call_count == 0
+
+        await client.start()
+
+        assert client._mqtt_client.set_credentials.call_count == 1
+        assert client._mqtt_client.set_credentials.call_args == mocker.call(
+            client._username, str(fake_sastoken)
+        )
+
+        await client.stop()
+
+    # NOTE: For testing the functionality of this task, see the corresponding test suite (TestIoTHubMQTTClientIncomingTwinResponse)
+    @pytest.mark.it(
+        "Begins running the ._process_twin_responses() coroutine method as a background task, storing it as an attribute"
+    )
+    async def test_process_twin_responses_bg_task(self, client):
+        assert client._process_twin_responses_bg_task is None
+
+        await client.start()
+
+        assert isinstance(client._process_twin_responses_bg_task, asyncio.Task)
+        assert not client._process_twin_responses_bg_task.done()
+        if sys.version_info > (3, 8):
+            # NOTE: There isn't a way to validate the contents of a task until 3.8
+            # as far as I can tell.
+ task_coro = client._process_twin_responses_bg_task.get_coro() + assert task_coro.__qualname__ == "IoTHubMQTTClient._process_twin_responses" + + # Cleanup + await client.stop() + + @pytest.mark.it( + "Does not alter any background tasks if already started, but does reset the credentials with the same values" + ) + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="SAS Auth"), + pytest.param(None, id="No SAS auth"), + ], + ) + async def test_already_started(self, client, sastoken_provider): + client._sastoken_provider = sastoken_provider + assert client._mqtt_client.set_credentials.call_count == 0 + + # Start + await client.start() + + # Current tasks + current_process_twin_responses_task = client._process_twin_responses_bg_task + # Credentials set + assert client._mqtt_client.set_credentials.call_count == 1 + credential_args = client._mqtt_client.set_credentials.call_args + + # Start again + await client.start() + + # Tasks unchanged + assert client._process_twin_responses_bg_task is current_process_twin_responses_task + # Credentials set again (the same values as before) + assert client._mqtt_client.set_credentials.call_count == 2 + assert client._mqtt_client.set_credentials.call_args == credential_args + + # Cleanup + await client.stop() + + +@pytest.mark.describe("IoTHubMQTTClient - .stop()") +class TestIoTHubMQTTClientStop: + @pytest.fixture(autouse=True) + async def modify_client(self, client, mock_sastoken_provider): + client._sastoken_provider = mock_sastoken_provider + # Need to start the client so we can stop it. + await client.start() + + @pytest.mark.it("Disconnects the MQTTClient") + async def test_disconnect(self, mocker, client): + # NOTE: rather than mocking the MQTTClient, we just mock the .disconnect() method of the + # IoTHubMQTTClient instead, since it's been fully tested elsewhere, and we assume + # correctness, lest we have to repeat all .disconnect() tests here. 
+ original_disconnect = client.disconnect + client.disconnect = mocker.AsyncMock() + try: + assert client.disconnect.await_count == 0 + + await client.stop() + + assert client.disconnect.await_count == 1 + assert client.disconnect.await_args == mocker.call() + finally: + client.disconnect = original_disconnect + + @pytest.mark.it( + "Cancels the 'process_twin_responses' background task and removes it, if it exists" + ) + async def test_process_twin_responses_bg_task(self, client): + assert isinstance(client._process_twin_responses_bg_task, asyncio.Task) + t = client._process_twin_responses_bg_task + assert not t.done() + + await client.stop() + + assert t.done() + assert t.cancelled() + assert client._process_twin_responses_bg_task is None + + # NOTE: Currently this is an invalid scenario. This shouldn't happen, but test it anyway. + @pytest.mark.it("Handles the case where no 'process_twin_responses' background task exists") + async def test_process_twin_responses_bg_task_no_exist(self, client): + # The task is already running, so cancel and remove it + assert isinstance(client._process_twin_responses_bg_task, asyncio.Task) + client._process_twin_responses_bg_task.cancel() + client._process_twin_responses_bg_task = None + + await client.stop() + # No AttributeError means success! + + @pytest.mark.it( + "Allows any exception raised during MQTTClient disconnect to propagate, but only after cancelling background tasks" + ) + @pytest.mark.parametrize("exception", mqtt_disconnect_exceptions) + async def test_disconnect_raises(self, mocker, client, exception): + # NOTE: rather than mocking the MQTTClient, we just mock the .disconnect() method of the + # IoTHubMQTTClient instead, since it's been fully tested elsewhere, and we assume + # correctness, lest we have to repeat all .disconnect() tests here. 
+ original_disconnect = client.disconnect + client.disconnect = mocker.AsyncMock(side_effect=exception) + try: + process_twin_responses_bg_task = client._process_twin_responses_bg_task + assert not process_twin_responses_bg_task.done() + + with pytest.raises(type(exception)) as e_info: + await client.stop() + assert e_info.value is exception + + # Background tasks were also cancelled despite the exception + assert process_twin_responses_bg_task.done() + assert process_twin_responses_bg_task.cancelled() + # And they were unset too + assert client._process_twin_responses_bg_task is None + finally: + # Unset the mock so that tests can clean up + client.disconnect = original_disconnect + + # TODO: when run by itself, this test leaves a task unresolved. Not sure why. Not too important. + @pytest.mark.it( + "Does not alter any background tasks if already stopped, but does disconnect again" + ) + async def test_already_stopped(self, mocker, client): + original_disconnect = client.disconnect + client.disconnect = mocker.AsyncMock() + try: + assert client.disconnect.await_count == 0 + + # Stop + await client.stop() + assert client._process_twin_responses_bg_task is None + assert client.disconnect.await_count == 1 + + # Stop again + await client.stop() + assert client._process_twin_responses_bg_task is None + assert client.disconnect.await_count == 2 + + finally: + client.disconnect = original_disconnect + + # TODO: when run by itself, this test leaves a task unresolved. Not sure why. Not too important. + @pytest.mark.it( + "Can be cancelled while waiting for the MQTTClient disconnect to finish, but it won't stop background task cancellation" + ) + async def test_cancel_disconnect(self, client): + # NOTE: rather than mocking the MQTTClient, we just mock the .disconnect() method of the + # IoTHubMQTTClient instead, since it's been fully tested elsewhere, and we assume + # correctness, lest we have to repeat all .disconnect() tests here. 
+ original_disconnect = client.disconnect + client.disconnect = custom_mock.HangingAsyncMock() + try: + process_twin_responses_bg_task = client._process_twin_responses_bg_task + assert not process_twin_responses_bg_task.done() + + t = asyncio.create_task(client.stop()) + + # Hanging, waiting for disconnect to finish + await client.disconnect.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + # Due to cancellation, the tasks we want to assert are done may need a moment + # to finish, since we aren't waiting on them to exit + await asyncio.sleep(0.1) + + # And yet the background tasks still were cancelled anyway + assert process_twin_responses_bg_task.done() + assert process_twin_responses_bg_task.cancelled() + # And they were unset too + assert client._process_twin_responses_bg_task is None + finally: + # Unset the mock so that tests can clean up. + client.disconnect = original_disconnect + + @pytest.mark.it( + "Can be cancelled while waiting for the background tasks to finish cancellation, but it won't stop the background task cancellation" + ) + async def test_cancel_gather(self, mocker, client): + original_gather = asyncio.gather + asyncio.gather = custom_mock.HangingAsyncMock() + spy_twin_response_bg_task_cancel = mocker.spy( + client._process_twin_responses_bg_task, "cancel" + ) + try: + process_twin_responses_bg_task = client._process_twin_responses_bg_task + assert not process_twin_responses_bg_task.done() + + t = asyncio.create_task(client.stop()) + + # Hanging waiting for gather to return (indicating tasks are all done cancellation) + await asyncio.gather.wait_for_hang() + assert not t.done() + # Background tests may or may not have completed cancellation yet, hard to test accurately. + # But their cancellation HAS been requested. 
+ assert spy_twin_response_bg_task_cancel.call_count == 1 + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + # Tasks will be cancelled very soon (if they aren't already) + await asyncio.sleep(0.1) + assert process_twin_responses_bg_task.done() + assert process_twin_responses_bg_task.cancelled() + # And they were unset too + assert client._process_twin_responses_bg_task is None + finally: + # Unset the mock so that tests can clean up. + asyncio.gather = original_gather + + +@pytest.mark.describe("IoTHubMQTTClient - .connect()") +class TestIoTHubMQTTClientConnect: + @pytest.mark.it("Awaits a connect using the MQTTClient") + async def test_mqtt_connect(self, mocker, client): + assert client._mqtt_client.connect.await_count == 0 + + await client.connect() + + assert client._mqtt_client.connect.await_count == 1 + assert client._mqtt_client.connect.await_args == mocker.call() + + @pytest.mark.it("Allows any exceptions raised during the MQTTClient connect to propagate") + @pytest.mark.parametrize("exception", mqtt_connect_exceptions) + async def test_mqtt_exception(self, client, exception): + client._mqtt_client.connect.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await client.connect() + assert e_info.value is exception + + @pytest.mark.it("Can be cancelled while waiting for the MQTTClient connect to finish") + async def test_cancel(self, client): + client._mqtt_client.connect = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(client.connect()) + + # Hanging, waiting for MQTT connect to finish + await client._mqtt_client.connect.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe("IoTHubMQTTClient - .disconnect()") +class TestIoTHubMQTTClientDisconnect: + @pytest.mark.it("Awaits a disconnect using the MQTTClient") + async def test_mqtt_disconnect(self, mocker, client): + assert 
client._mqtt_client.disconnect.await_count == 0 + + await client.disconnect() + + assert client._mqtt_client.disconnect.await_count == 1 + assert client._mqtt_client.disconnect.await_args == mocker.call() + + @pytest.mark.it("Allows any exceptions raised during the MQTTClient disconnect to propagate") + @pytest.mark.parametrize("exception", mqtt_disconnect_exceptions) + async def test_mqtt_exception(self, client, exception): + client._mqtt_client.disconnect.side_effect = exception + try: + with pytest.raises(type(exception)) as e_info: + await client.disconnect() + assert e_info.value is exception + finally: + # Unset the side effect for cleanup (since shutdown uses disconnect) + client._mqtt_client.disconnect.side_effect = None + + @pytest.mark.it("Can be cancelled while waiting for the MQTTClient disconnect to finish") + async def test_cancel(self, mocker, client): + client._mqtt_client.disconnect = custom_mock.HangingAsyncMock() + try: + t = asyncio.create_task(client.disconnect()) + + # Hanging, waiting for MQTT disconnect to finish + await client._mqtt_client.disconnect.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + finally: + # Unset the HangingMock for clean (since shutdown uses disconnect) + client._mqtt_client.disconnect = mocker.AsyncMock() + + +@pytest.mark.describe("IoTHubMQTTClient - .wait_for_disconnect()") +class TestIoTHubMQTTClientReportConnectionDrop: + @pytest.mark.it( + "Returns None if an expected disconnect has previously ocurred in the MQTTClient" + ) + async def test_previous_expected_disconnect(self, client): + # Simulate expected disconnect + client._mqtt_client._disconnection_cause = None + client._mqtt_client.is_connected.return_value = False + async with client._mqtt_client.disconnected_cond: + client._mqtt_client.disconnected_cond.notify_all() + + # Reports no cause (i.e. 
expected disconnect) + t = asyncio.create_task(client.wait_for_disconnect()) + await asyncio.sleep(0.1) + assert t.done() + assert t.result() is None + + @pytest.mark.it( + "Waits for a disconnect to occur in the MQTTClient, and returns None once an expected disconnect occurs, if no disconnect has yet ocurred" + ) + async def test_expected_disconnect(self, client): + # No connection drop to report + t = asyncio.create_task(client.wait_for_disconnect()) + await asyncio.sleep(0.1) + assert not t.done() + + # Simulate expected disconnect + client._mqtt_client._disconnection_cause = None + client._mqtt_client.is_connected.return_value = False + async with client._mqtt_client.disconnected_cond: + client._mqtt_client.disconnected_cond.notify_all() + + # Report no cause (i.e. expected disconnect) + await asyncio.sleep(0.1) + assert t.done() + assert t.result() is None + + @pytest.mark.it( + "Returns the MQTTConnectionDroppedError that caused an unexpected disconnect in the MQTTClient, if an unexpected disconnect has already occurred" + ) + async def test_previous_unexpected_disconnect(self, client): + # Simulate unexpected disconnect + cause = exc.MQTTConnectionDroppedError(rc=7) + client._mqtt_client._disconnection_cause = cause + client._mqtt_client.is_connected.return_value = False + async with client._mqtt_client.disconnected_cond: + client._mqtt_client.disconnected_cond.notify_all() + + # Reports the cause that is already available + t = asyncio.create_task(client.wait_for_disconnect()) + await asyncio.sleep(0.1) + assert t.done() + assert t.result() is cause + + @pytest.mark.it( + "Waits for a disconnect to occur in the MQTTClient, and returns the MQTTError that caused it once an unexpected disconnect occurs, if no disconnect has not yet ocurred" + ) + async def test_unexpected_disconnect(self, client): + # No connection drop to report yet + t = asyncio.create_task(client.wait_for_disconnect()) + await asyncio.sleep(0.1) + assert not t.done() + + # Simulate 
unexpected disconnect + cause = exc.MQTTError(rc=7) + client._mqtt_client._disconnection_cause = cause + client._mqtt_client.is_connected.return_value = False + async with client._mqtt_client.disconnected_cond: + client._mqtt_client.disconnected_cond.notify_all() + + # Cause can now be reported + await asyncio.sleep(0.1) + assert t.done() + assert t.result() is cause + + +@pytest.mark.describe("IoTHubMQTTClient - .send_message()") +class TestIoTHubMQTTClientSendMessage: + @pytest.fixture + def message(self): + return models.Message("some payload") + + @pytest.mark.it( + "Awaits a publish to the telemetry topic using the MQTTClient, sending the given Message's payload converted to bytes" + ) + @pytest.mark.parametrize( + "device_id, module_id", + [ + pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"), + ], + ) + async def test_mqtt_publish(self, mocker, client, device_id, module_id): + assert client._mqtt_client.publish.await_count == 0 + client._device_id = device_id + client._module_id = module_id + message = models.Message(payload="some_payload") + base_topic = mqtt_topic.get_telemetry_topic_for_publish(device_id, module_id) + expected_topic = mqtt_topic.insert_message_properties_in_topic( + topic=base_topic, + system_properties=message.get_system_properties_dict(), + custom_properties=message.custom_properties, + ) + expected_payload = message.payload.encode("utf-8") + + assert client._mqtt_client.publish.await_count == 0 + await client.send_message(message) + + assert client._mqtt_client.publish.await_count == 1 + assert client._mqtt_client.publish.await_args == mocker.call( + expected_topic, expected_payload + ) + assert isinstance(expected_payload, bytes) + + @pytest.mark.it( + "Derives the byte payload from the Message payload according to the Message's content encoding and content type properties" + ) + @pytest.mark.parametrize( + "device_id, module_id", + [ + 
pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"), + ], + ) + @pytest.mark.parametrize("content_encoding", ["utf-8", "utf-16", "utf-32"]) + @pytest.mark.parametrize( + "content_type, payload, expected_str_payload", + [ + pytest.param("text/plain", "some_text", "some_text", id="text/plain"), + pytest.param( + "application/json", {"some": "json"}, '{"some": "json"}', id="application/json" + ), + ], + ) + async def test_publish_payload( + self, + mocker, + client, + device_id, + module_id, + content_encoding, + content_type, + payload, + expected_str_payload, + ): + client._device_id = device_id + client._module_id = module_id + message = models.Message( + payload=payload, content_encoding=content_encoding, content_type=content_type + ) + base_topic = mqtt_topic.get_telemetry_topic_for_publish(device_id, module_id) + expected_topic = mqtt_topic.insert_message_properties_in_topic( + topic=base_topic, + system_properties=message.get_system_properties_dict(), + custom_properties=message.custom_properties, + ) + expected_byte_payload = expected_str_payload.encode(content_encoding) + + await client.send_message(message) + + assert client._mqtt_client.publish.await_count == 1 + assert client._mqtt_client.publish.await_args == mocker.call( + expected_topic, expected_byte_payload + ) + + @pytest.mark.it("Supports any string-convertible payload when using text/plain content type") + @pytest.mark.parametrize( + "device_id, module_id", + [ + pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"), + ], + ) + @pytest.mark.parametrize( + "payload", + [ + pytest.param("String Payload", id="String Payload"), + pytest.param(1234, id="Int Payload"), + pytest.param(2.0, id="Float Payload"), + pytest.param(True, id="Boolean Payload"), + pytest.param([1, 2, 3], id="List Payload"), + pytest.param({"some": {"dictionary": 
"value"}}, id="Dictionary Payload"), + pytest.param((1, 2), id="Tuple Payload"), + pytest.param(None, id="No Payload"), + ], + ) + async def test_text_plain_payload(self, mocker, client, device_id, module_id, payload): + client._device_id = device_id + client._module_id = module_id + message = models.Message(payload=payload, content_type="text/plain") + base_topic = mqtt_topic.get_telemetry_topic_for_publish(device_id, module_id) + expected_topic = mqtt_topic.insert_message_properties_in_topic( + topic=base_topic, + system_properties=message.get_system_properties_dict(), + custom_properties=message.custom_properties, + ) + expected_byte_payload = str(message.payload).encode("utf-8") + + await client.send_message(message) + + assert client._mqtt_client.publish.await_count == 1 + assert client._mqtt_client.publish.await_args == mocker.call( + expected_topic, expected_byte_payload + ) + + @pytest.mark.it( + "Supports any JSON-serializable payload when using application/json content type" + ) + @pytest.mark.parametrize( + "device_id, module_id", + [ + pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"), + ], + ) + @pytest.mark.parametrize( + "payload", + [ + pytest.param("String Payload", id="String Payload"), + pytest.param(1234, id="Int Payload"), + pytest.param(2.0, id="Float Payload"), + pytest.param(True, id="Boolean Payload"), + pytest.param([1, 2, 3], id="List Payload"), + pytest.param({"some": {"dictionary": "value"}}, id="Dictionary Payload"), + pytest.param((1, 2), id="Tuple Payload"), + pytest.param(None, id="No Payload"), + ], + ) + async def test_application_json_payload(self, mocker, client, device_id, module_id, payload): + client._device_id = device_id + client._module_id = module_id + message = models.Message(payload=payload, content_type="application/json") + base_topic = mqtt_topic.get_telemetry_topic_for_publish(device_id, module_id) + expected_topic = 
mqtt_topic.insert_message_properties_in_topic( + topic=base_topic, + system_properties=message.get_system_properties_dict(), + custom_properties=message.custom_properties, + ) + expected_byte_payload = json.dumps(message.payload).encode("utf-8") + + await client.send_message(message) + + assert client._mqtt_client.publish.await_count == 1 + assert client._mqtt_client.publish.await_args == mocker.call( + expected_topic, expected_byte_payload + ) + + @pytest.mark.it("Inserts any Message properties in the telemetry topic") + @pytest.mark.parametrize( + "device_id, module_id", + [ + pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"), + ], + ) + async def test_message_properties(self, mocker, client, device_id, module_id, message): + assert client._mqtt_client.publish.await_count == 0 + client._device_id = device_id + client._module_id = module_id + + message.message_id = "some message id" + message.content_encoding = "utf-8" + message.content_type = "text/plain" + message.output_name = "some output" + message.custom_properties["custom_property1"] = 123 + message.custom_properties["custom_property2"] = "456" + message.set_as_security_message() + base_topic = mqtt_topic.get_telemetry_topic_for_publish(device_id, module_id) + expected_topic = mqtt_topic.insert_message_properties_in_topic( + topic=base_topic, + system_properties=message.get_system_properties_dict(), + custom_properties=message.custom_properties, + ) + + assert "%24.mid" in expected_topic # message_id + assert "%24.ce" in expected_topic # content_encoding + assert "%24.ct" in expected_topic # content_type + assert "%24.on" in expected_topic # output_name + assert "%24.ifid" in expected_topic # security message indicator + assert "custom_property1" in expected_topic # custom property + assert "custom_property2" in expected_topic # custom property + + await client.send_message(message) + + assert 
client._mqtt_client.publish.await_count == 1 + assert client._mqtt_client.publish.await_args == mocker.call(expected_topic, mocker.ANY) + + @pytest.mark.it("Allows any exceptions raised from the MQTTClient publish to propagate") + @pytest.mark.parametrize("exception", mqtt_publish_exceptions) + async def test_mqtt_exception(self, client, exception, message): + client._mqtt_client.publish.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await client.send_message(message) + assert e_info.value is exception + + @pytest.mark.it("Can be cancelled while waiting for the MQTTClient publish to finish") + async def test_cancel(self, client, message): + client._mqtt_client.publish = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(client.send_message(message)) + + # Hanging, waiting for MQTT publish to finish + await client._mqtt_client.publish.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe("IoTHubMQTTClient - .send_direct_method_response()") +class TestIoTHubMQTTClientSendDirectMethodResponse: + @pytest.fixture + def method_response(self): + json_response = {"some": {"json": "payload"}} + method_response = models.DirectMethodResponse( + request_id="123", status=200, payload=json_response + ) + return method_response + + @pytest.mark.it( + "Awaits a publish to the direct method response topic using the MQTTClient, sending the given DirectMethodResponse's JSON payload converted to string" + ) + async def test_mqtt_publish(self, mocker, client, method_response): + assert client._mqtt_client.publish.await_count == 0 + + expected_topic = mqtt_topic.get_direct_method_response_topic_for_publish( + method_response.request_id, method_response.status + ) + expected_payload = json.dumps(method_response.payload) + + await client.send_direct_method_response(method_response) + + assert client._mqtt_client.publish.await_count == 1 + assert 
client._mqtt_client.publish.await_args == mocker.call( + expected_topic, expected_payload + ) + + @pytest.mark.it("Allows any exceptions raised from the MQTTClient publish to propagate") + @pytest.mark.parametrize("exception", mqtt_publish_exceptions) + async def test_mqtt_exception(self, client, method_response, exception): + client._mqtt_client.publish.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await client.send_direct_method_response(method_response) + assert e_info.value is exception + + @pytest.mark.it("Can be cancelled while waiting for the MQTTClient publish to finish") + async def test_cancel(self, client, method_response): + client._mqtt_client.publish = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(client.send_direct_method_response(method_response)) + + # Hanging, waiting for MQTT publish to finish + await client._mqtt_client.publish.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe("IoTHubMQTTClient - .send_twin_patch()") +class TestIoTHubMQTTClientSendTwinPatch: + @pytest.fixture + def twin_patch(self): + return {"property": "updated_value"} + + @pytest.fixture(autouse=True) + def modify_publish(self, client): + # Add a side effect to publish that will complete the pending request for that request id. + # This will allow most tests to be able to ignore request/response infrastructure mocks. + # If this is not the desired behavior (such as in tests OF the request/response paradigm) + # override the publish behavior. 
+ # + # To see tests regarding how this actually works in practice, see the relevant test suite + async def fake_publish(topic, payload): + rid = topic[topic.rfind("$rid=") :].split("=")[1] + response = rr.Response(rid, 200, "body") + await client._request_ledger.match_response(response) + + client._mqtt_client.publish.side_effect = fake_publish + + @pytest.mark.it( + "Awaits a subscribe to the twin response topic using the MQTTClient, if twin responses have not already been enabled" + ) + async def test_mqtt_subscribe_not_enabled(self, mocker, client, twin_patch): + assert client._mqtt_client.subscribe.await_count == 0 + assert client._twin_responses_enabled is False + expected_topic = mqtt_topic.get_twin_response_topic_for_subscribe() + + await client.send_twin_patch(twin_patch) + + assert client._mqtt_client.subscribe.await_count == 1 + assert client._mqtt_client.subscribe.await_args == mocker.call(expected_topic) + + @pytest.mark.it("Does not perform a subscribe if twin responses have already been enabled") + async def test_mqtt_subscribe_already_enabled(self, client, twin_patch): + assert client._mqtt_client.subscribe.await_count == 0 + client._twin_responses_enabled = True + + await client.send_twin_patch(twin_patch) + + assert client._mqtt_client.subscribe.call_count == 0 + + @pytest.mark.it("Sets the twin_response_enabled flag to True upon subscribe success") + async def test_response_enabled_flag_success(self, client, twin_patch): + assert client._twin_responses_enabled is False + + await client.send_twin_patch(twin_patch) + + assert client._twin_responses_enabled is True + + @pytest.mark.it("Generates a new Request, using the RequestLedger stored on the client") + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Twin Responses Already Enabled"), + pytest.param(False, id="Twin Responses Not Yet Enabled"), + ], + ) + async def test_generate_request(self, mocker, client, twin_patch, responses_enabled): + 
client._twin_responses_enabled = responses_enabled + spy_create_request = mocker.spy(client._request_ledger, "create_request") + + await client.send_twin_patch(twin_patch) + + assert spy_create_request.await_count == 1 + + @pytest.mark.it( + "Awaits a publish to the twin patch topic using the MQTTClient, sending the given twin patch JSON converted to string" + ) + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Twin Responses Already Enabled"), + pytest.param(False, id="Twin Responses Not Yet Enabled"), + ], + ) + async def test_mqtt_publish(self, mocker, client, twin_patch, responses_enabled): + client._twin_responses_enabled = responses_enabled + spy_create_request = mocker.spy(client._request_ledger, "create_request") + assert client._mqtt_client.publish.await_count == 0 + + await client.send_twin_patch(twin_patch) + + request = spy_create_request.spy_return + expected_topic = mqtt_topic.get_twin_patch_topic_for_publish(request.request_id) + expected_payload = json.dumps(twin_patch) + + assert client._mqtt_client.publish.await_count == 1 + assert client._mqtt_client.publish.await_args == mocker.call( + expected_topic, expected_payload + ) + + @pytest.mark.it("Awaits a received Response to the Request") + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Twin Responses Already Enabled"), + pytest.param(False, id="Twin Responses Not Yet Enabled"), + ], + ) + async def test_get_response(self, mocker, client, twin_patch, responses_enabled): + client._twin_responses_enabled = responses_enabled + # Override autocompletion behavior on publish (we don't want it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = "fake_request_id" # Need this for string manipulation + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + # Mock out the request 
to return a response + mock_response = mocker.MagicMock(spec=rr.Response) + mock_response.status = 200 + mock_request.get_response.return_value = mock_response + + await client.send_twin_patch(twin_patch) + + assert mock_request.get_response.await_count == 1 + assert mock_request.get_response.await_args == mocker.call() + + @pytest.mark.it("Returns None if a successful status is received via the Response") + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Twin Responses Already Enabled"), + pytest.param(False, id="Twin Responses Not Yet Enabled"), + ], + ) + @pytest.mark.parametrize( + "successful_status", + [ + pytest.param(200, id="Status Code: 200"), + pytest.param(204, id="Status Code: 204"), + ], + ) + async def test_success_response( + self, mocker, client, twin_patch, responses_enabled, successful_status + ): + client._twin_responses_enabled = responses_enabled + # Override autocompletion behavior on publish (we don't want it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = "fake_request_id" # Need this for string manipulation + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + # Mock out the request to return a response + mock_response = mocker.MagicMock(spec=rr.Response) + mock_response.status = successful_status + mock_request.get_response.return_value = mock_response + + result = await client.send_twin_patch(twin_patch) + assert result is None + + @pytest.mark.it("Raises an IoTHubError if an unsuccessful status is received via the Response") + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Twin Responses Already Enabled"), + pytest.param(False, id="Twin Responses Not Yet Enabled"), + ], + ) + @pytest.mark.parametrize( + "failed_status", + [ + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status 
Code: 400"), + pytest.param(500, id="Status Code: 500"), + ], + ) + async def test_failed_response( + self, mocker, client, twin_patch, responses_enabled, failed_status + ): + client._twin_responses_enabled = responses_enabled + # Override autocompletion behavior on publish (we don't need it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = "fake_request_id" # Need this for string manipulation + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + # Mock out the request to return a response + mock_response = mocker.MagicMock(spec=rr.Response) + mock_response.status = failed_status + mock_request.get_response.return_value = mock_response + + with pytest.raises(exc.IoTHubError): + await client.send_twin_patch(twin_patch) + + # NOTE: MQTTClient subscribe can generate it's own cancellations due to network failure. + # This is different from a user-initiated cancellation + @pytest.mark.it("Allows any exceptions raised from the MQTTClient subscribe to propagate") + @pytest.mark.parametrize("exception", mqtt_subscribe_exceptions) + async def test_mqtt_subscribe_exception(self, client, twin_patch, exception): + assert client._twin_responses_enabled is False + client._mqtt_client.subscribe.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await client.send_twin_patch(twin_patch) + assert e_info.value is exception + + # NOTE: MQTTClient subscribe can generate it's own cancellations due to network failure. 
+ # This is different from a user-initiated cancellation + @pytest.mark.it( + "Does not set the twin_response_enabled flag to True or create a Request if MQTTClient subscribe raises" + ) + @pytest.mark.parametrize("exception", mqtt_subscribe_exceptions) + async def test_subscribe_exception_cleanup(self, mocker, client, twin_patch, exception): + assert client._twin_responses_enabled is False + spy_create_request = mocker.spy(client._request_ledger, "create_request") + client._mqtt_client.subscribe.side_effect = exception + + with pytest.raises(type(exception)): + await client.send_twin_patch(twin_patch) + + assert client._twin_responses_enabled is False + assert spy_create_request.await_count == 0 + + # NOTE: This is a user invoked cancel, as opposed to one above, which was generated by the + # MQTTClient in response to a network failure. + @pytest.mark.it( + "Does not set the twin_response_enabled flag to True or create a Request if cancelled while waiting for the MQTT subscribe to finish" + ) + async def test_mqtt_subscribe_cancel_cleanup(self, mocker, client, twin_patch): + assert client._twin_responses_enabled is False + spy_create_request = mocker.spy(client._request_ledger, "create_request") + client._mqtt_client.subscribe = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(client.send_twin_patch(twin_patch)) + + # Hanging, waiting for MQTT publish to finish + await client._mqtt_client.subscribe.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + assert client._twin_responses_enabled is False + assert spy_create_request.await_count == 0 + + @pytest.mark.it("Allows any exceptions raised from the MQTTClient publish to propagate") + @pytest.mark.parametrize("exception", mqtt_publish_exceptions) + async def test_mqtt_publish_exception(self, client, twin_patch, exception): + client._mqtt_client.publish.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await 
client.send_twin_patch(twin_patch) + assert e_info.value is exception + + @pytest.mark.it("Deletes the Request from the RequestLedger if MQTTClient publish raises") + @pytest.mark.parametrize("exception", mqtt_publish_exceptions) + async def test_mqtt_publish_exception_cleanup(self, mocker, client, twin_patch, exception): + client._mqtt_client.publish.side_effect = exception + spy_create_request = mocker.spy(client._request_ledger, "create_request") + spy_delete_request = mocker.spy(client._request_ledger, "delete_request") + + with pytest.raises(type(exception)): + await client.send_twin_patch(twin_patch) + + # The Request that was created was also deleted + assert spy_create_request.await_count == 1 + assert spy_create_request.await_args == mocker.call() + assert spy_delete_request.await_count == 1 + assert spy_delete_request.await_args == mocker.call( + spy_create_request.spy_return.request_id + ) + + @pytest.mark.it( + "Deletes the Request from the RequestLedger if cancelled while waiting for the MQTTClient publish to finish" + ) + async def test_mqtt_publish_cancel_cleanup(self, mocker, client, twin_patch): + client._mqtt_client.publish = custom_mock.HangingAsyncMock() + spy_create_request = mocker.spy(client._request_ledger, "create_request") + spy_delete_request = mocker.spy(client._request_ledger, "delete_request") + + t = asyncio.create_task(client.send_twin_patch(twin_patch)) + + # Hanging, waiting for MQTT publish to finish + await client._mqtt_client.publish.wait_for_hang() + assert not t.done() + + # Request was created, but not yet deleted + assert spy_create_request.await_count == 1 + assert spy_create_request.await_args == mocker.call() + assert spy_delete_request.await_count == 0 + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + # The Request that was created has now been deleted + assert spy_delete_request.await_count == 1 + assert spy_delete_request.await_args == mocker.call( + 
spy_create_request.spy_return.request_id + ) + + @pytest.mark.it( + "Deletes the Request from the RequestLedger if cancelled while waiting for a twin response" + ) + async def test_waiting_response_cancel_cleanup(self, mocker, client, twin_patch): + # Override autocompletion behavior on publish (we don't want it here) + client._mqtt_client.publish = mocker.AsyncMock() + + # Mock Request creation to return a specific, mocked request that hangs on + # awaiting a Response + request = rr.Request() + request.get_response = custom_mock.HangingAsyncMock() + mocker.patch.object(rr, "Request", return_value=request) + spy_create_request = mocker.spy(client._request_ledger, "create_request") + spy_delete_request = mocker.spy(client._request_ledger, "delete_request") + + send_task = asyncio.create_task(client.send_twin_patch(twin_patch)) + + # Hanging, waiting for response + await request.get_response.wait_for_hang() + assert not send_task.done() + + # Request was created, but not yet deleted + assert spy_create_request.await_count == 1 + assert spy_create_request.await_args == mocker.call() + assert spy_delete_request.await_count == 0 + + # Cancel + send_task.cancel() + with pytest.raises(asyncio.CancelledError): + await send_task + + # The Request that was created was also deleted + assert spy_delete_request.await_count == 1 + assert spy_delete_request.await_args == mocker.call(request.request_id) + + +@pytest.mark.describe("IoTHubMQTTClient - .get_twin()") +class TestIoTHubMQTTClientGetTwin: + @pytest.fixture(autouse=True) + def modify_publish(self, client): + # Add a side effect to publish that will complete the pending request for that request id. + # This will allow most tests to be able to ignore request/response infrastructure mocks. + # If this is not the desired behavior (such as in tests OF the request/response paradigm) + # override the publish behavior. 
+ # + # To see tests regarding how this actually works in practice, see the relevant test suite + async def fake_publish(topic, payload): + rid = topic[topic.rfind("$rid=") :].split("=")[1] + response = rr.Response(rid, 200, '{"json": "in", "a": {"string": "format"}}') + await client._request_ledger.match_response(response) + + client._mqtt_client.publish.side_effect = fake_publish + + @pytest.mark.it( + "Awaits a subscribe to the twin response topic using the MQTTClient, if twin responses have not already been enabled" + ) + async def test_mqtt_subscribe_not_enabled(self, mocker, client): + assert client._mqtt_client.subscribe.await_count == 0 + assert client._twin_responses_enabled is False + expected_topic = mqtt_topic.get_twin_response_topic_for_subscribe() + + await client.get_twin() + + assert client._mqtt_client.subscribe.await_count == 1 + assert client._mqtt_client.subscribe.await_args == mocker.call(expected_topic) + + @pytest.mark.it("Does not perform a subscribe if twin responses have already been enabled") + async def test_mqtt_subscribe_already_enabled(self, client): + assert client._mqtt_client.subscribe.await_count == 0 + client._twin_responses_enabled = True + + await client.get_twin() + + assert client._mqtt_client.subscribe.call_count == 0 + + @pytest.mark.it("Sets the twin_response_enabled flag to True upon subscribe success") + async def test_response_enabled_flag_success(self, client): + assert client._twin_responses_enabled is False + + await client.get_twin() + + assert client._twin_responses_enabled is True + + @pytest.mark.it("Generates a new Request, using the RequestLedger stored on the client") + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Twin Responses Already Enabled"), + pytest.param(False, id="Twin Responses Not Yet Enabled"), + ], + ) + async def test_generate_request(self, mocker, client, responses_enabled): + client._twin_responses_enabled = responses_enabled + spy_create_request = 
mocker.spy(client._request_ledger, "create_request") + + await client.get_twin() + + assert spy_create_request.await_count == 1 + + @pytest.mark.it("Awaits a publish to the twin request topic using the MQTTClient") + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Twin Responses Already Enabled"), + pytest.param(False, id="Twin Responses Not Yet Enabled"), + ], + ) + async def test_mqtt_publish(self, mocker, client, responses_enabled): + client._twin_responses_enabled = responses_enabled + spy_create_request = mocker.spy(client._request_ledger, "create_request") + assert client._mqtt_client.publish.await_count == 0 + + await client.get_twin() + + request = spy_create_request.spy_return + expected_topic = mqtt_topic.get_twin_request_topic_for_publish(request.request_id) + expected_payload = " " + + assert client._mqtt_client.publish.await_count == 1 + assert client._mqtt_client.publish.await_args == mocker.call( + expected_topic, expected_payload + ) + + @pytest.mark.it("Awaits a received Response to the Request") + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Twin Responses Already Enabled"), + pytest.param(False, id="Twin Responses Not Yet Enabled"), + ], + ) + async def test_get_response(self, mocker, client, responses_enabled): + client._twin_responses_enabled = responses_enabled + # Override autocompletion behavior on publish (we don't want it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = "fake_request_id" # Need this for string manipulation + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + # Mock out the request to return a response + mock_response = mocker.MagicMock(spec=rr.Response) + mock_response.status = 200 + mock_response.body = '{"json": "in", "a": {"string": "format"}}' + 
mock_request.get_response.return_value = mock_response + + await client.get_twin() + + assert mock_request.get_response.await_count == 1 + assert mock_request.get_response.await_args == mocker.call() + + @pytest.mark.it("Raises an IoTHubError if an unsuccessful status is received via the Response") + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Twin Responses Already Enabled"), + pytest.param(False, id="Twin Responses Not Yet Enabled"), + ], + ) + @pytest.mark.parametrize( + "failed_status", + [ + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status Code: 400"), + pytest.param(500, id="Status Code: 500"), + ], + ) + async def test_failed_response(self, mocker, client, responses_enabled, failed_status): + client._twin_responses_enabled = responses_enabled + # Override autocompletion behavior on publish (we don't need it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = "fake_request_id" # Need this for string manipulation + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + # Mock out the request to return a response + mock_response = mocker.MagicMock(spec=rr.Response) + mock_response.status = failed_status + mock_response.body = " " + mock_request.get_response.return_value = mock_response + + with pytest.raises(exc.IoTHubError): + await client.get_twin() + + @pytest.mark.it( + "Returns the twin received in the Response, converted to JSON, if the Response status was successful" + ) + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Twin Responses Already Enabled"), + pytest.param(False, id="Twin Responses Not Yet Enabled"), + ], + ) + @pytest.mark.parametrize( + "successful_status", + [ + pytest.param(200, id="Status Code: 200"), + pytest.param(204, id="Status Code: 204"), + ], + ) + async def 
test_success_response(self, mocker, client, responses_enabled, successful_status): + client._twin_responses_enabled = responses_enabled + # Override autocompletion behavior on publish (we don't need it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = "fake_request_id" # Need this for string manipulation + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + # Mock out the request to return a response + mock_response = mocker.MagicMock(spec=rr.Response) + mock_response.status = successful_status + fake_twin_string = '{"json": "in", "a": {"string": "format"}}' + mock_response.body = fake_twin_string + mock_request.get_response.return_value = mock_response + + twin = await client.get_twin() + assert twin == json.loads(fake_twin_string) + + # NOTE: MQTTClient subscribe can generate it's own cancellations due to network failure. + # This is different from a user-initiated cancellation + @pytest.mark.it("Allows any exceptions raised from the MQTTClient subscribe to propagate") + @pytest.mark.parametrize("exception", mqtt_subscribe_exceptions) + async def test_mqtt_subscribe_exception(self, client, exception): + assert client._twin_responses_enabled is False + client._mqtt_client.subscribe.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await client.get_twin() + assert e_info.value is exception + + # NOTE: MQTTClient subscribe can generate it's own cancellations due to network failure. 
+ # This is different from a user-initiated cancellation + @pytest.mark.it( + "Does not set the twin_response_enabled flag to True or create a Request if MQTTClient subscribe raises" + ) + @pytest.mark.parametrize("exception", mqtt_subscribe_exceptions) + async def test_subscribe_exception_cleanup(self, mocker, client, exception): + assert client._twin_responses_enabled is False + spy_create_request = mocker.spy(client._request_ledger, "create_request") + client._mqtt_client.subscribe.side_effect = exception + + with pytest.raises(type(exception)): + await client.get_twin() + + assert client._twin_responses_enabled is False + assert spy_create_request.await_count == 0 + + # NOTE: This is a user invoked cancel, as opposed to one above, which was generated by the + # MQTTClient in response to a network failure. + @pytest.mark.it( + "Does not set the twin_response_enabled flag to True or create a Request if cancelled while waiting for the MQTTClient subscribe to finish" + ) + async def test_mqtt_subscribe_cancel_cleanup(self, mocker, client): + assert client._twin_responses_enabled is False + spy_create_request = mocker.spy(client._request_ledger, "create_request") + client._mqtt_client.subscribe = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(client.get_twin()) + + # Hanging, waiting for MQTT publish to finish + await client._mqtt_client.subscribe.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + assert client._twin_responses_enabled is False + assert spy_create_request.await_count == 0 + + @pytest.mark.it("Allows any exceptions raised from the MQTTClient publish to propagate") + @pytest.mark.parametrize("exception", mqtt_publish_exceptions) + async def test_mqtt_publish_exception(self, client, exception): + client._mqtt_client.publish.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await client.get_twin() + assert e_info.value is exception + + 
@pytest.mark.it("Deletes the Request from the RequestLedger if MQTTClient publish raises") + @pytest.mark.parametrize("exception", mqtt_publish_exceptions) + async def test_mqtt_publish_exception_cleanup(self, mocker, client, exception): + client._mqtt_client.publish.side_effect = exception + spy_create_request = mocker.spy(client._request_ledger, "create_request") + spy_delete_request = mocker.spy(client._request_ledger, "delete_request") + + with pytest.raises(type(exception)): + await client.get_twin() + + assert spy_create_request.await_count == 1 + assert spy_create_request.await_args == mocker.call() + assert spy_delete_request.await_count == 1 + assert spy_delete_request.await_args == mocker.call( + spy_create_request.spy_return.request_id + ) + + @pytest.mark.it( + "Deletes the Request from the RequestLedger if cancelled while waiting for the MQTTClient publish to finish" + ) + async def test_mqtt_publish_cancel_cleanup(self, mocker, client): + client._mqtt_client.publish = custom_mock.HangingAsyncMock() + spy_create_request = mocker.spy(client._request_ledger, "create_request") + spy_delete_request = mocker.spy(client._request_ledger, "delete_request") + + t = asyncio.create_task(client.get_twin()) + + # Hanging, waiting for MQTT publish to finish + await client._mqtt_client.publish.wait_for_hang() + assert not t.done() + + # Request was created, but not yet deleted + assert spy_create_request.await_count == 1 + assert spy_create_request.await_args == mocker.call() + assert spy_delete_request.await_count == 0 + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + # The Request that was created has now been deleted + assert spy_delete_request.await_count == 1 + assert spy_delete_request.await_args == mocker.call( + spy_create_request.spy_return.request_id + ) + + @pytest.mark.it( + "Deletes the Request from the RequestLedger if cancelled while waiting for a twin response" + ) + async def 
test_waiting_response_cancel_cleanup(self, mocker, client): + # Override autocompletion behavior on publish (we don't want it here) + client._mqtt_client.publish = mocker.AsyncMock() + + # Mock Request creation to return a specific, mocked request that hangs on + # awaiting a Response + request = rr.Request() + request.get_response = custom_mock.HangingAsyncMock() + mocker.patch.object(rr, "Request", return_value=request) + spy_create_request = mocker.spy(client._request_ledger, "create_request") + spy_delete_request = mocker.spy(client._request_ledger, "delete_request") + + send_task = asyncio.create_task(client.get_twin()) + + # Hanging, waiting for response + await request.get_response.wait_for_hang() + assert not send_task.done() + + # Request was created, but not yet deleted + assert spy_create_request.await_count == 1 + assert spy_create_request.await_args == mocker.call() + assert spy_delete_request.await_count == 0 + + # Cancel + send_task.cancel() + with pytest.raises(asyncio.CancelledError): + await send_task + + # The Request that was created was also deleted + assert spy_delete_request.await_count == 1 + assert spy_delete_request.await_args == mocker.call(request.request_id) + + +class IoTHubMQTTClientEnableReceiveTest(abc.ABC): + """Base class for the .enable_x() methods""" + + @abc.abstractmethod + @pytest.fixture + def method_name(self): + """Return the name of the enable method under test""" + pass + + @abc.abstractmethod + @pytest.fixture + def expected_topic(self): + """Return the expected topic string to subscribe to""" + pass + + @pytest.mark.it("Awaits a subscribe to the associated incoming data topic using the MQTTClient") + async def test_mqtt_subscribe(self, mocker, client, method_name, expected_topic): + assert client._mqtt_client.subscribe.await_count == 0 + + method = getattr(client, method_name) + await method() + + assert client._mqtt_client.subscribe.await_count == 1 + assert client._mqtt_client.subscribe.await_args == 
mocker.call(expected_topic) + + @pytest.mark.it("Allows any exceptions raised from the MQTTClient subscribe to propagate") + @pytest.mark.parametrize("exception", mqtt_subscribe_exceptions) + async def test_mqtt_subscribe_exception(self, client, method_name, exception): + client._mqtt_client.subscribe.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + method = getattr(client, method_name) + await method() + assert e_info.value is exception + + @pytest.mark.it("Can be cancelled while waiting for the MQTTClient subscribe to finish") + async def test_cancel(self, client, method_name): + client._mqtt_client.subscribe = custom_mock.HangingAsyncMock() + + method = getattr(client, method_name) + t = asyncio.create_task(method()) + + # Hanging, waiting for MQTT subscribe to finish + await client._mqtt_client.subscribe.wait_for_hang() + assert not t.done() + + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +class IoTHubMQTTClientDisableReceiveTest(abc.ABC): + """Base class for the .disable_x() methods""" + + @abc.abstractmethod + @pytest.fixture + def method_name(self): + """Return the name of the disable method under test""" + pass + + @abc.abstractmethod + @pytest.fixture + def expected_topic(self): + """Return the expected topic string to unsubscribe from""" + pass + + @pytest.mark.it( + "Awaits an unsubscribe from the associated incoming data topic using the MQTTClient" + ) + async def test_mqtt_unsubscribe(self, mocker, client, method_name, expected_topic): + assert client._mqtt_client.unsubscribe.await_count == 0 + + method = getattr(client, method_name) + await method() + + assert client._mqtt_client.unsubscribe.await_count == 1 + assert client._mqtt_client.unsubscribe.await_args == mocker.call(expected_topic) + + @pytest.mark.it("Allows any exceptions raised from the MQTTClient unsubscribe to propagate") + @pytest.mark.parametrize("exception", mqtt_subscribe_exceptions) + async def 
test_mqtt_unsubscribe_exception(self, client, method_name, exception): + client._mqtt_client.unsubscribe.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + method = getattr(client, method_name) + await method() + assert e_info.value is exception + + @pytest.mark.it("Can be cancelled while waiting for the MQTTClient unsubscribe to finish") + async def test_cancel(self, client, method_name): + client._mqtt_client.unsubscribe = custom_mock.HangingAsyncMock() + + method = getattr(client, method_name) + t = asyncio.create_task(method()) + + # Hanging, waiting for MQTT subscribe to finish + await client._mqtt_client.unsubscribe.wait_for_hang() + assert not t.done() + + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe("IoTHubMQTTClient - .enable_c2d_message_receive()") +class TestIoTHubMQTTClientEnableC2DMessageReceive(IoTHubMQTTClientEnableReceiveTest): + @pytest.fixture + def method_name(self): + return "enable_c2d_message_receive" + + @pytest.fixture + def expected_topic(self, client): + return mqtt_topic.get_c2d_topic_for_subscribe(client._device_id) + + @pytest.mark.it("Raises IoTHubClientError if client not configured for a Device") + async def test_with_module_id(self, client): + client._module_id = FAKE_MODULE_ID + with pytest.raises(exc.IoTHubClientError): + await client.enable_c2d_message_receive() + + +@pytest.mark.describe("IoTHubMQTTClient - .disable_c2d_message_receive()") +class TestIoTHubMQTTClientDisableC2DMessageReceive(IoTHubMQTTClientDisableReceiveTest): + @pytest.fixture + def method_name(self): + return "disable_c2d_message_receive" + + @pytest.fixture + def expected_topic(self, client): + return mqtt_topic.get_c2d_topic_for_subscribe(client._device_id) + + @pytest.mark.it("Raises IoTHubClientError if client not configured for a Device") + async def test_with_module_id(self, client): + client._module_id = FAKE_MODULE_ID + with pytest.raises(exc.IoTHubClientError): + await 
@pytest.mark.describe("IoTHubMQTTClient - .enable_input_message_receive()")
class TestIoTHubMQTTClientEnableInputMessageReceive(IoTHubMQTTClientEnableReceiveTest):
    @pytest.fixture(autouse=True)
    def modify_client(self, client):
        """Add a module ID to the client"""
        client._module_id = FAKE_MODULE_ID

    @pytest.fixture
    def method_name(self):
        return "enable_input_message_receive"

    @pytest.fixture
    def expected_topic(self, client):
        return mqtt_topic.get_input_topic_for_subscribe(client._device_id, client._module_id)

    @pytest.mark.it("Raises IoTHubClientError if client not configured for a Module")
    async def test_no_module_id(self, client):
        client._module_id = None
        with pytest.raises(exc.IoTHubClientError):
            await client.enable_input_message_receive()


@pytest.mark.describe("IoTHubMQTTClient - .disable_input_message_receive()")
class TestIoTHubMQTTClientDisableInputMessageReceive(IoTHubMQTTClientDisableReceiveTest):
    @pytest.fixture(autouse=True)
    def modify_client(self, client):
        """Add a module ID to the client"""
        client._module_id = FAKE_MODULE_ID

    @pytest.fixture
    def method_name(self):
        return "disable_input_message_receive"

    @pytest.fixture
    def expected_topic(self, client):
        return mqtt_topic.get_input_topic_for_subscribe(client._device_id, client._module_id)

    @pytest.mark.it("Raises IoTHubClientError if client not configured for a Module")
    async def test_no_module_id(self, client):
        client._module_id = None
        with pytest.raises(exc.IoTHubClientError):
            await client.disable_input_message_receive()


@pytest.mark.describe("IoTHubMQTTClient - .enable_direct_method_request_receive()")
class TestIoTHubMQTTClientEnableDirectMethodRequestReceive(IoTHubMQTTClientEnableReceiveTest):
    @pytest.fixture
    def method_name(self):
        return "enable_direct_method_request_receive"

    @pytest.fixture
    def expected_topic(self):
        return mqtt_topic.get_direct_method_request_topic_for_subscribe()
mqtt_topic.get_direct_method_request_topic_for_subscribe() + + +@pytest.mark.describe("IoTHubMQTTClient - .disable_direct_method_request_receive()") +class TestIoTHubMQTTClientDisableDirectMethodRequestReceive(IoTHubMQTTClientDisableReceiveTest): + @pytest.fixture + def method_name(self): + return "disable_direct_method_request_receive" + + @pytest.fixture + def expected_topic(self): + return mqtt_topic.get_direct_method_request_topic_for_subscribe() + + +@pytest.mark.describe("IoTHubMQTTClient - .enable_twin_patch_receive()") +class TestIoTHubMQTTClientEnableTwinPatchReceive(IoTHubMQTTClientEnableReceiveTest): + @pytest.fixture + def method_name(self): + return "enable_twin_patch_receive" + + @pytest.fixture + def expected_topic(self): + return mqtt_topic.get_twin_patch_topic_for_subscribe() + + +@pytest.mark.describe("IoTHubMQTTClient - .disable_twin_patch_receive()") +class TestIoTHubMQTTClientDisableTwinPatchReceive(IoTHubMQTTClientDisableReceiveTest): + @pytest.fixture + def method_name(self): + return "disable_twin_patch_receive" + + @pytest.fixture + def expected_topic(self): + return mqtt_topic.get_twin_patch_topic_for_subscribe() + + +@pytest.mark.describe("IoTHubMQTTClient - PROPERTY: .incoming_c2d_messages") +class TestIoTHubMQTTClientIncomingC2DMessages: + @pytest.fixture(autouse=True) + def modify_client_config(self, client_config): + # C2D Messages only work for Device configurations + # NOTE: This has to be changed on the config, not the client, + # because it affects client initialization + client_config.module_id = None + + @pytest.mark.it( + "Is an AsyncGenerator maintained as a read-only property, if using a Device Configuration" + ) + def test_property_device(self, client): + assert client._device_id is not None + assert client._module_id is None + assert isinstance(client.incoming_c2d_messages, typing.AsyncGenerator) + with pytest.raises(AttributeError): + client.incoming_c2d_messages = 12 + + @pytest.mark.it("Raises IoTHubClientError when 
accessed, if not using Device Configuration") + async def test_property_module(self, client_config): + # Need to modify config and re-instantiate the client here because generators are created + # at instantiation time + client_config.module_id = FAKE_MODULE_ID + client = IoTHubMQTTClient(client_config) + with pytest.raises(exc.IoTHubClientError): + client.incoming_c2d_messages + + @pytest.mark.it( + "Yields a Message whenever the MQTTClient receives an MQTTMessage on the incoming C2D message topic" + ) + async def test_yields_message(self, client): + sub_topic = mqtt_topic.get_c2d_topic_for_subscribe(client._device_id) + receive_topic = sub_topic.rstrip("#") + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg2 = mqtt.MQTTMessage(mid=2, topic=receive_topic.encode("utf-8")) + # Load the MQTTMessages into the MQTTClient's filtered message queue + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg2) + # Get items from the generator + msg1 = await client.incoming_c2d_messages.__anext__() + assert isinstance(msg1, models.Message) + msg2 = await client.incoming_c2d_messages.__anext__() + assert isinstance(msg2, models.Message) + assert msg1 != msg2 + + @pytest.mark.it( + "Derives the yielded Message payload from the MQTTMessage byte payload according to the content encoding and content type properties contained in the MQTTMessage's topic" + ) + @pytest.mark.parametrize("content_encoding", ["utf-8", "utf-16", "utf-32"]) + @pytest.mark.parametrize( + "content_type, payload_str, expected_payload", + [ + pytest.param("text/plain", "some_payload", "some_payload", id="text/plain"), + pytest.param( + "application/json", '{"some": "json"}', {"some": "json"}, id="application/json" + ), + ], + ) + async def test_payload( + self, client, content_encoding, payload_str, content_type, expected_payload + ): + sub_topic = 
mqtt_topic.get_c2d_topic_for_subscribe(client._device_id) + receive_topic = mqtt_topic.insert_message_properties_in_topic( + topic=sub_topic.rstrip("#"), + system_properties={"$.ce": content_encoding, "$.ct": content_type}, + custom_properties={}, + ) + # NOTE: topics are always utf-8 encoded, even if the payload is different + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg.payload = payload_str.encode(content_encoding) + + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg) + msg = await client.incoming_c2d_messages.__anext__() + assert msg.payload != mqtt_msg.payload + assert msg.payload == expected_payload + + @pytest.mark.it( + "Supports conversion to JSON object for any valid JSON string payload when using application/json content type" + ) + @pytest.mark.parametrize( + "str_payload, expected_json_payload", + [ + pytest.param('"String Payload"', "String Payload", id="String Payload"), + pytest.param("1234", 1234, id="Int Payload"), + pytest.param("2.0", 2.0, id="Float Payload"), + pytest.param("true", True, id="Boolean Payload"), + pytest.param("[1, 2, 3]", [1, 2, 3], id="List Payload"), + pytest.param( + '{"some": {"dictionary": "value"}}', + {"some": {"dictionary": "value"}}, + id="Dictionary Payload", + ), + pytest.param("null", None, id="No Payload"), + ], + ) + async def test_application_json_payload(self, client, str_payload, expected_json_payload): + sub_topic = mqtt_topic.get_c2d_topic_for_subscribe(client._device_id) + receive_topic = mqtt_topic.insert_message_properties_in_topic( + topic=sub_topic.rstrip("#"), + system_properties={"$.ct": "application/json"}, + custom_properties={}, + ) + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg.payload = str_payload.encode("utf-8") + + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg) + msg = await client.incoming_c2d_messages.__anext__() + assert msg.payload == expected_json_payload 
+ + @pytest.mark.it( + "Uses a default utf-8 codec to decode the MQTTMessage byte payload if no content encoding property is contained in the MQTTMessage's topic" + ) + @pytest.mark.parametrize( + "content_type, payload_str, expected_payload", + [ + pytest.param("text/plain", "some_payload", "some_payload", id="text/plain"), + pytest.param( + "application/json", '{"some": "json"}', {"some": "json"}, id="application/json" + ), + ], + ) + async def test_payload_content_encoding_default( + self, client, content_type, payload_str, expected_payload + ): + sub_topic = mqtt_topic.get_c2d_topic_for_subscribe(client._device_id) + receive_topic = mqtt_topic.insert_message_properties_in_topic( + topic=sub_topic.rstrip("#"), + system_properties={"$.ct": content_type}, + custom_properties={}, + ) + assert "$.ce" not in receive_topic + + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg.payload = payload_str.encode("utf-8") + + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg) + msg = await client.incoming_c2d_messages.__anext__() + assert msg.payload != mqtt_msg.payload + assert msg.payload == expected_payload + + @pytest.mark.it( + "Treats the payload as text/plain content by default if no content type property is contained in the MQTTMessage's topic" + ) + @pytest.mark.parametrize("content_encoding", ["utf-8", "utf-16", "utf-32"]) + async def test_payload_content_type_default(self, client, content_encoding): + sub_topic = mqtt_topic.get_c2d_topic_for_subscribe(client._device_id) + receive_topic = mqtt_topic.insert_message_properties_in_topic( + topic=sub_topic.rstrip("#"), + system_properties={"$.ce": content_encoding}, + custom_properties={}, + ) + assert "$.ct" not in receive_topic + + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + payload_str = '{"payload": "that", "could": "be", "json": {"or could be": "string"}}' + mqtt_msg.payload = payload_str.encode(content_encoding) + + await 
client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg) + msg = await client.incoming_c2d_messages.__anext__() + assert msg.payload != mqtt_msg.payload + assert msg.payload == payload_str + assert msg.payload != json.loads(payload_str) + + @pytest.mark.it( + "Sets any message properties contained in the MQTTMessage's topic onto the resulting Message" + ) + @pytest.mark.parametrize( + "system_properties, custom_properties", + [ + pytest.param( + { + "$.mid": "message_id", + "$.ce": "utf-8", + "$.ct": "text/plain", + "iothub-ack": "ack", + "$.exp": "expiry_time", + "$.uid": "user_id", + "$.cid": "correlation_id", + }, + {}, + id="System Properties Only", + ), + pytest.param( + { + "$.mid": "message_id", + "$.ce": "utf-8", + "$.ct": "text/plain", + "iothub-ack": "ack", + "$.exp": "expiry_time", + "$.uid": "user_id", + "$.cid": "correlation_id", + }, + {"cust_prop1": "value1", "cust_prop2": "value2"}, + id="System Properties and Custom Properties", + ), + ], + ) + async def test_message_properties(self, client, system_properties, custom_properties): + sub_topic = mqtt_topic.get_c2d_topic_for_subscribe(client._device_id) + receive_topic = mqtt_topic.insert_message_properties_in_topic( + topic=sub_topic.rstrip("#"), + system_properties=system_properties, + custom_properties=custom_properties, + ) + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg.payload = "some payload".encode("utf-8") + + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg) + msg = await client.incoming_c2d_messages.__anext__() + assert msg.get_system_properties_dict() == system_properties + assert msg.custom_properties == custom_properties + + @pytest.mark.it( + "Sets default system properties onto the resulting Message if they are not provided in the MQTTMessage's topic" + ) + async def test_message_property_defaults(self, client): + sub_topic = mqtt_topic.get_c2d_topic_for_subscribe(client._device_id) + receive_topic = 
mqtt_topic.insert_message_properties_in_topic( + topic=sub_topic.rstrip("#"), + system_properties={}, + custom_properties={}, + ) + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg) + msg = await client.incoming_c2d_messages.__anext__() + assert msg.get_system_properties_dict() == {"$.ce": "utf-8", "$.ct": "text/plain"} + assert msg.content_type == "text/plain" + assert msg.content_encoding == "utf-8" + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while extracting the message properties from the MQTTMessage, dropping the MQTTMessage and continuing" + ) + async def test_property_extraction_fails(self, mocker, client, arbitrary_exception): + # Create two messages + sub_topic = mqtt_topic.get_c2d_topic_for_subscribe(client._device_id) + receive_topic = sub_topic.rstrip("#") + # MQTTMessage1 + payload1 = "Message #1" + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg1.payload = payload1.encode("utf-8") + # MQTTMessage2 + payload2 = "Message #2" + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg2.payload = payload2.encode("utf-8") + + # Inject failure in first extraction only + original_fn = mqtt_topic.extract_properties_from_message_topic + mock_extract = mocker.patch.object(mqtt_topic, "extract_properties_from_message_topic") + + def fail_once(*args, **kwargs): + mock_extract.side_effect = original_fn + raise arbitrary_exception + + mock_extract.side_effect = fail_once + + # Load the MQTTMessages + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg2) + + # The Message is derived from the second MQTTMessage instead of the first because the + # first failed, the error was suppressed, and the MQTTMessage discarded + msg = await client.incoming_c2d_messages.__anext__() + 
assert msg.payload == payload2 + assert payload2 != payload1 + assert mock_extract.call_count == 2 + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while decoding the payload from the MQTTMessage, dropping the MQTTMessage and continuing" + ) + async def test_payload_decode_fails(self, mocker, client, arbitrary_exception): + # Create two messages + sub_topic = mqtt_topic.get_c2d_topic_for_subscribe(client._device_id) + receive_topic = sub_topic.rstrip("#") + # MQTTMessage 1 (No payload due to mock below) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + # MQTTMessage 2 + payload2 = "Message #2" + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg2.payload = payload2.encode("utf-8") + + # Inject failure to decode in first MQTTMessage only + mqtt_msg1.payload = mocker.MagicMock() + mqtt_msg1.payload.decode.side_effect = arbitrary_exception + + # Load the MQTTMessages + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg2) + + # The Message is derived from the second MQTTMessage instead of the first because the + # first failed, the error was suppressed, and the MQTTMessage discarded + msg = await client.incoming_c2d_messages.__anext__() + assert msg.payload == payload2 + assert mqtt_msg1.payload.decode.call_count == 1 + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while converting the payload from the MQTTMessage to JSON, dropping the MQTTMessage and continuing" + ) + async def test_json_loads_fails(self, mocker, client, arbitrary_exception): + # Create two messages + sub_topic = mqtt_topic.get_c2d_topic_for_subscribe(client._device_id) + receive_topic = mqtt_topic.insert_message_properties_in_topic( + topic=sub_topic.rstrip("#"), + system_properties={"$.ct": "application/json"}, + custom_properties={}, + ) + # MQTTMessage1 + payload1 = {"some": "json"} + 
mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg1.payload = json.dumps(payload1).encode("utf-8") + # MQTTMessage2 + payload2 = {"some_other": "json"} + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg2.payload = json.dumps(payload2).encode("utf-8") + + # Inject failure in the first json conversion only + original_loads = json.loads + mock_loads = mocker.patch.object(json, "loads") + + def fail_once(*args, **kwargs): + mock_loads.side_effect = original_loads + raise arbitrary_exception + + mock_loads.side_effect = fail_once + + # Load the MQTTMessages + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg2) + + # The Message is derived from the second MQTTMessage instead of the first because the + # first failed, the error was suppressed, and the MQTTMessage discarded + msg = await client.incoming_c2d_messages.__anext__() + assert msg.payload == payload2 + assert payload2 != payload1 + assert mock_loads.call_count == 2 + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while instantiating the Message object from the MQTTMessage values, dropping the MQTTMessage and continuing" + ) + async def test_message_instantiation_fails(self, mocker, client, arbitrary_exception): + # Create two messages + sub_topic = mqtt_topic.get_c2d_topic_for_subscribe(client._device_id) + receive_topic = sub_topic.rstrip("#") + # MQTTMessage1 + payload1 = "Message #1" + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg1.payload = payload1.encode("utf-8") + # MQTTMessage2 + payload2 = "Message #2" + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg2.payload = payload2.encode("utf-8") + + # Inject failure in first extraction only + original_factory = models.Message.create_from_properties_dict + mock_factory = 
mocker.patch.object(models.Message, "create_from_properties_dict") + + def fail_once(*args, **kwargs): + mock_factory.side_effect = original_factory + raise arbitrary_exception + + mock_factory.side_effect = fail_once + + # Load the MQTTMessages + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg2) + + # The Message is derived from the second MQTTMessage instead of the first because the + # first failed, the error was suppressed, and the MQTTMessage discarded + msg = await client.incoming_c2d_messages.__anext__() + assert msg.payload == payload2 + assert payload2 != payload1 + assert mock_factory.call_count == 2 + + @pytest.mark.skip(reason="Currently can't figure out how to mock a generator correctly") + @pytest.mark.it("Can be cancelled while waiting for an MQTTMessage to arrive") + async def test_cancelled_while_waiting_for_message(self): + pass + + +@pytest.mark.describe("IoTHubMQTTClient - PROPERTY: .incoming_input_messages") +class TestIoTHubMQTTClientIncomingInputMessages: + @pytest.fixture(autouse=True) + def modify_client_config(self, client_config): + # Input Messages only work for Module Configuration. 
+ # NOTE: This has to be changed on the config, not the client, + # because it affects client initialization + client_config.module_id = FAKE_MODULE_ID + + @pytest.mark.it( + "Is an AsyncGenerator maintained as a read-only property, if using a Module Configuration" + ) + def test_property_module(self, client): + assert client._device_id is not None + assert client._module_id is not None + assert isinstance(client.incoming_input_messages, typing.AsyncGenerator) + with pytest.raises(AttributeError): + client.incoming_input_messages = 12 + + @pytest.mark.it("Raises IoTHubClientError when accessed, if not using a Module Configuration") + async def test_property_device(self, client_config): + # Need to modify config and re-instantiate the client here because generators are created + # at instantiation time + client_config.module_id = None + client = IoTHubMQTTClient(client_config) + with pytest.raises(exc.IoTHubClientError): + client.incoming_input_messages + + @pytest.mark.it( + "Yields a Message whenever the MQTTClient receives an MQTTMessage on the incoming Input message topic" + ) + async def test_yields_message(self, client): + sub_topic = mqtt_topic.get_input_topic_for_subscribe(client._device_id, client._module_id) + receive_topic = sub_topic.rstrip("#") + FAKE_INPUT_NAME + "/" + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg2 = mqtt.MQTTMessage(mid=2, topic=receive_topic.encode("utf-8")) + # Load the MQTTMessages into the MQTTClient's filtered message queue + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg2) + # Get items from the generator + msg1 = await client.incoming_input_messages.__anext__() + assert isinstance(msg1, models.Message) + msg2 = await client.incoming_input_messages.__anext__() + assert isinstance(msg2, models.Message) + assert msg1 != msg2 + + @pytest.mark.it( + "Derives the yielded Message 
payload from the MQTTMessage byte payload according to the content encoding and content type properties contained in the MQTTMessage's topic" + ) + @pytest.mark.parametrize("content_encoding", ["utf-8", "utf-16", "utf-32"]) + @pytest.mark.parametrize( + "content_type, payload_str, expected_payload", + [ + pytest.param("text/plain", "some_payload", "some_payload", id="text/plain"), + pytest.param( + "application/json", '{"some": "json"}', {"some": "json"}, id="application/json" + ), + ], + ) + async def test_payload( + self, client, content_encoding, payload_str, content_type, expected_payload + ): + sub_topic = mqtt_topic.get_input_topic_for_subscribe(client._device_id, client._module_id) + receive_topic = mqtt_topic.insert_message_properties_in_topic( + topic=sub_topic.rstrip("#") + FAKE_INPUT_NAME + "/", + system_properties={"$.ce": content_encoding, "$.ct": content_type}, + custom_properties={}, + ) + # NOTE: topics are always utf-8 encoded, even if the payload is different + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg.payload = payload_str.encode(content_encoding) + + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg) + msg = await client.incoming_input_messages.__anext__() + assert msg.payload != mqtt_msg.payload + assert msg.payload == expected_payload + + @pytest.mark.it( + "Supports conversion to JSON object for any valid JSON string payload when using application/json content type" + ) + @pytest.mark.parametrize( + "str_payload, expected_json_payload", + [ + pytest.param('"String Payload"', "String Payload", id="String Payload"), + pytest.param("1234", 1234, id="Int Payload"), + pytest.param("2.0", 2.0, id="Float Payload"), + pytest.param("true", True, id="Boolean Payload"), + pytest.param("[1, 2, 3]", [1, 2, 3], id="List Payload"), + pytest.param( + '{"some": {"dictionary": "value"}}', + {"some": {"dictionary": "value"}}, + id="Dictionary Payload", + ), + pytest.param("null", None, id="No 
Payload"), + ], + ) + async def test_application_json_payload(self, client, str_payload, expected_json_payload): + sub_topic = mqtt_topic.get_input_topic_for_subscribe(client._device_id, client._module_id) + receive_topic = mqtt_topic.insert_message_properties_in_topic( + topic=sub_topic.rstrip("#") + FAKE_INPUT_NAME + "/", + system_properties={"$.ct": "application/json"}, + custom_properties={}, + ) + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg.payload = str_payload.encode("utf-8") + + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg) + msg = await client.incoming_input_messages.__anext__() + assert msg.payload == expected_json_payload + + @pytest.mark.it( + "Uses a default utf-8 codec to decode the MQTTMessage byte payload if no content encoding property is contained in the MQTTMessage's topic" + ) + @pytest.mark.parametrize( + "content_type, payload_str, expected_payload", + [ + pytest.param("text/plain", "some_payload", "some_payload", id="text/plain"), + pytest.param( + "application/json", '{"some": "json"}', {"some": "json"}, id="application/json" + ), + ], + ) + async def test_payload_content_encoding_default( + self, client, content_type, payload_str, expected_payload + ): + sub_topic = mqtt_topic.get_input_topic_for_subscribe(client._device_id, client._module_id) + receive_topic = mqtt_topic.insert_message_properties_in_topic( + topic=sub_topic.rstrip("#") + FAKE_INPUT_NAME + "/", + system_properties={"$.ct": content_type}, + custom_properties={}, + ) + assert "$.ce" not in receive_topic + + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg.payload = payload_str.encode("utf-8") + + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg) + msg = await client.incoming_input_messages.__anext__() + assert msg.payload != mqtt_msg.payload + assert msg.payload == expected_payload + + @pytest.mark.it( + "Treats the payload as text/plain 
content by default if no content type property is contained in the MQTTMessage's topic" + ) + @pytest.mark.parametrize("content_encoding", ["utf-8", "utf-16", "utf-32"]) + async def test_payload_content_type_default(self, client, content_encoding): + sub_topic = mqtt_topic.get_input_topic_for_subscribe(client._device_id, client._module_id) + receive_topic = mqtt_topic.insert_message_properties_in_topic( + topic=sub_topic.rstrip("#") + FAKE_INPUT_NAME + "/", + system_properties={"$.ce": content_encoding}, + custom_properties={}, + ) + assert "$.ct" not in receive_topic + + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + payload_str = '{"payload": "that", "could": "be", "json": {"or could be": "string"}}' + mqtt_msg.payload = payload_str.encode(content_encoding) + + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg) + msg = await client.incoming_input_messages.__anext__() + assert msg.payload != mqtt_msg.payload + assert msg.payload == payload_str + assert msg.payload != json.loads(payload_str) + + @pytest.mark.it( + "Sets any message properties contained in the MQTTMessage's topic onto the resulting Message" + ) + @pytest.mark.parametrize( + "system_properties, custom_properties", + [ + pytest.param( + { + "$.mid": "message_id", + "$.to": "some_input", + "$.ce": "utf-8", + "$.ct": "text/plain", + "iothub-ack": "ack", + "$.exp": "expiry_time", + "$.uid": "user_id", + "$.cid": "correlation_id", + }, + {}, + id="System Properties Only", + ), + pytest.param( + { + "$.mid": "message_id", + "$.to": "some_input", + "$.ce": "utf-8", + "$.ct": "text/plain", + "iothub-ack": "ack", + "$.exp": "expiry_time", + "$.uid": "user_id", + "$.cid": "correlation_id", + }, + {"cust_prop1": "value1", "cust_prop2": "value2"}, + id="System Properties and Custom Properties", + ), + ], + ) + async def test_message_properties(self, client, system_properties, custom_properties): + sub_topic = 
mqtt_topic.get_input_topic_for_subscribe(client._device_id, client._module_id) + receive_topic = mqtt_topic.insert_message_properties_in_topic( + topic=sub_topic.rstrip("#") + FAKE_INPUT_NAME + "/", + system_properties=system_properties, + custom_properties=custom_properties, + ) + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg.payload = "some payload".encode("utf-8") + + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg) + msg = await client.incoming_input_messages.__anext__() + assert msg.get_system_properties_dict() == system_properties + assert msg.custom_properties == custom_properties + + @pytest.mark.it( + "Sets default system properties onto the resulting Message if they are not provided in the MQTTMessage's topic" + ) + async def test_message_property_defaults(self, client): + sub_topic = mqtt_topic.get_input_topic_for_subscribe(client._device_id, client._module_id) + receive_topic = mqtt_topic.insert_message_properties_in_topic( + topic=sub_topic.rstrip("#") + FAKE_INPUT_NAME + "/", + system_properties={}, + custom_properties={}, + ) + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg) + msg = await client.incoming_input_messages.__anext__() + assert msg.get_system_properties_dict() == {"$.ce": "utf-8", "$.ct": "text/plain"} + assert msg.content_type == "text/plain" + assert msg.content_encoding == "utf-8" + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while extracting the message properties from the MQTTMessage, dropping the MQTTMessage and continuing" + ) + async def test_property_extraction_fails(self, mocker, client, arbitrary_exception): + # Create two messages + sub_topic = mqtt_topic.get_input_topic_for_subscribe(client._device_id, client._module_id) + receive_topic = sub_topic.rstrip("#") + FAKE_INPUT_NAME + "/" + # MQTTMessage1 + payload1 = "Message #1" + 
mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg1.payload = payload1.encode("utf-8") + # MQTTMessage2 + payload2 = "Message #2" + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg2.payload = payload2.encode("utf-8") + + # Inject failure in first extraction only + original_fn = mqtt_topic.extract_properties_from_message_topic + mock_extract = mocker.patch.object(mqtt_topic, "extract_properties_from_message_topic") + + def fail_once(*args, **kwargs): + mock_extract.side_effect = original_fn + raise arbitrary_exception + + mock_extract.side_effect = fail_once + + # Load the MQTTMessages + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg2) + + # The Message is derived from the second MQTTMessage instead of the first because the + # first failed, the error was suppressed, and the MQTTMessage discarded + msg = await client.incoming_input_messages.__anext__() + assert msg.payload == payload2 + assert payload2 != payload1 + assert mock_extract.call_count == 2 + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while decoding the payload from the MQTTMessage, dropping the MQTTMessage and continuing" + ) + async def test_payload_decode_fails(self, mocker, client, arbitrary_exception): + # Create two messages + sub_topic = mqtt_topic.get_input_topic_for_subscribe(client._device_id, client._module_id) + receive_topic = sub_topic.rstrip("#") + FAKE_INPUT_NAME + "/" + # MQTTMessage 1 (No payload due to mock below) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + # MQTTMessage 2 + payload2 = "Message #2" + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg2.payload = payload2.encode("utf-8") + + # Inject failure to decode in first MQTTMessage only + mqtt_msg1.payload = mocker.MagicMock() + mqtt_msg1.payload.decode.side_effect = 
arbitrary_exception + + # Load the MQTTMessages + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg2) + + # The Message is derived from the second MQTTMessage instead of the first because the + # first failed, the error was suppressed, and the MQTTMessage discarded + msg = await client.incoming_input_messages.__anext__() + assert msg.payload == payload2 + assert mqtt_msg1.payload.decode.call_count == 1 + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while converting the payload from the MQTTMessage to JSON, dropping the MQTTMessage and continuing" + ) + async def test_json_loads_fails(self, mocker, client, arbitrary_exception): + # Create two messages + sub_topic = mqtt_topic.get_input_topic_for_subscribe(client._device_id, client._module_id) + receive_topic = mqtt_topic.insert_message_properties_in_topic( + topic=sub_topic.rstrip("#") + FAKE_INPUT_NAME + "/", + system_properties={"$.ct": "application/json"}, + custom_properties={}, + ) + # MQTTMessage1 + payload1 = {"some": "json"} + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg1.payload = json.dumps(payload1).encode("utf-8") + # MQTTMessage2 + payload2 = {"some_other": "json"} + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg2.payload = json.dumps(payload2).encode("utf-8") + + # Inject failure in the first json conversion only + original_loads = json.loads + mock_loads = mocker.patch.object(json, "loads") + + def fail_once(*args, **kwargs): + mock_loads.side_effect = original_loads + raise arbitrary_exception + + mock_loads.side_effect = fail_once + + # Load the MQTTMessages + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg2) + + # The Message is derived from the second MQTTMessage instead of the first 
because the + # first failed, the error was suppressed, and the MQTTMessage discarded + msg = await client.incoming_input_messages.__anext__() + assert msg.payload == payload2 + assert payload2 != payload1 + assert mock_loads.call_count == 2 + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while instantiating the Message object from the MQTTMessage values, dropping the MQTTMessage and continuing" + ) + async def test_message_instantiation_fails(self, mocker, client, arbitrary_exception): + # Create two messages + sub_topic = mqtt_topic.get_input_topic_for_subscribe(client._device_id, client._module_id) + receive_topic = sub_topic.rstrip("#") + FAKE_INPUT_NAME + "/" + # MQTTMessage1 + payload1 = "Message #1" + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg1.payload = payload1.encode("utf-8") + # MQTTMessage2 + payload2 = "Message #2" + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=receive_topic.encode("utf-8")) + mqtt_msg2.payload = payload2.encode("utf-8") + + # Inject failure in first extraction only + original_factory = models.Message.create_from_properties_dict + mock_factory = mocker.patch.object(models.Message, "create_from_properties_dict") + + def fail_once(*args, **kwargs): + mock_factory.side_effect = original_factory + raise arbitrary_exception + + mock_factory.side_effect = fail_once + + # Load the MQTTMessages + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[sub_topic].put(mqtt_msg2) + + # The Message is derived from the second MQTTMessage instead of the first because the + # first failed, the error was suppressed, and the MQTTMessage discarded + msg = await client.incoming_input_messages.__anext__() + assert msg.payload == payload2 + assert payload2 != payload1 + assert mock_factory.call_count == 2 + + @pytest.mark.skip(reason="Currently can't figure out how to mock a generator correctly") + @pytest.mark.it("Can 
be cancelled while waiting for an MQTTMessage to arrive") + async def test_cancelled_while_waiting_for_message(self): + pass + + +@pytest.mark.describe("IoTHubMQTTClient - PROPERTY: .incoming_direct_method_requests") +class TestIoTHubMQTTClientIncomingDirectMethodRequests: + @pytest.mark.it("Is an AsyncGenerator maintained as a read-only property") + def test_property(self, client): + assert isinstance(client.incoming_direct_method_requests, typing.AsyncGenerator) + with pytest.raises(AttributeError): + client.incoming_direct_method_requests = 12 + + @pytest.mark.it( + "Yields a DirectMethodRequest whenever the MQTTClient receives an MQTTMessage on the incoming direct method request topic" + ) + async def test_yields_direct_(self, client): + generic_topic = mqtt_topic.get_direct_method_request_topic_for_subscribe() + + # Create MQTTMessages + mreq1_name = "some_method" + mreq1_id = "12" + mreq1_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(mreq1_name, mreq1_id) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=mreq1_topic.encode("utf-8")) + mqtt_msg1.payload = '{"json": "in", "a": {"string": "format"}}'.encode("utf-8") + mreq2_name = "some_other_method" + mreq2_id = "17" + mreq2_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(mreq2_name, mreq2_id) + mqtt_msg2 = mqtt.MQTTMessage(mid=2, topic=mreq2_topic.encode("utf-8")) + mqtt_msg2.payload = '{"json": "in", "a": {"different": {"string": "format"}}}'.encode( + "utf-8" + ) + + # Load the MQTTMessages into the MQTTClient's filtered message queue + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + + # Get items from generator + mreq1 = await client.incoming_direct_method_requests.__anext__() + assert isinstance(mreq1, models.DirectMethodRequest) + mreq2 = await client.incoming_direct_method_requests.__anext__() + assert isinstance(mreq2, models.DirectMethodRequest) + assert mreq1 != 
mreq2 + + @pytest.mark.it( + "Extracts the method name and request id from the MQTTMessage's topic and sets them on the resulting DirectMethodRequest" + ) + async def test_direct_method_request_attributes(self, client): + generic_topic = mqtt_topic.get_direct_method_request_topic_for_subscribe() + + mreq_name = "some_method" + mreq_id = "12" + mreq_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(mreq_name, mreq_id) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=mreq_topic.encode("utf-8")) + mqtt_msg1.payload = '{"json": "in", "a": {"string": "format"}}'.encode("utf-8") + + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + mreq = await client.incoming_direct_method_requests.__anext__() + + assert mreq.name == mreq_name + assert mreq.request_id == mreq_id + + @pytest.mark.it( + "Derives the yielded DirectMethodRequest JSON payload from the MQTTMessage's byte payload using the utf-8 codec" + ) + async def test_payload(self, client): + generic_topic = mqtt_topic.get_direct_method_request_topic_for_subscribe() + expected_payload = {"json": "derived", "from": {"byte": "payload"}} + + mreq_name = "some_method" + mreq_id = "12" + mreq_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(mreq_name, mreq_id) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=mreq_topic.encode("utf-8")) + mqtt_msg1.payload = json.dumps(expected_payload).encode("utf-8") + + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + mreq = await client.incoming_direct_method_requests.__anext__() + + assert mreq.payload == expected_payload + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while extracting the request id from the MQTTMessage, dropping the MQTTMessage and continuing" + ) + async def test_request_id_extraction_fails(self, mocker, client, arbitrary_exception): + # Create two messages + generic_topic = mqtt_topic.get_direct_method_request_topic_for_subscribe() + payload = {"json": "derived", "from": 
{"byte": "payload"}} + # MQTTMessage 1 + mreq_name1 = "some_method" + mreq_topic1 = generic_topic.rstrip("#") + "{}/?$rid={}".format(mreq_name1, 1) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=mreq_topic1.encode("utf-8")) + mqtt_msg1.payload = json.dumps(payload).encode("utf-8") + # MQTTMessage 2 + mreq_name2 = "some_other_method" + mreq_topic2 = generic_topic.rstrip("#") + "{}/?$rid={}".format(mreq_name2, 2) + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=mreq_topic2.encode("utf-8")) + mqtt_msg2.payload = json.dumps(payload).encode("utf-8") + + # Inject failure into the first extraction only + original_fn = mqtt_topic.extract_request_id_from_direct_method_request_topic + mock_extract = mocker.patch.object( + mqtt_topic, "extract_request_id_from_direct_method_request_topic" + ) + + def fail_once(*args, **kwargs): + mock_extract.side_effect = original_fn + raise arbitrary_exception + + mock_extract.side_effect = fail_once + + # Load the MQTTMessages + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + + # The DirectMethodResponse is derived from the second MQTTMessage instead of the first, + # because the first failed, the error was suppressed, and the MQTTMessage discarded + mreq = await client.incoming_direct_method_requests.__anext__() + assert mreq.name == mreq_name2 + assert mreq_name2 != mreq_name1 + assert mock_extract.call_count == 2 + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while extracting the method name from the MQTTMessage, dropping the MQTTMessage and continuing" + ) + async def test_method_name_extraction_fails(self, mocker, client, arbitrary_exception): + # Create two messages + generic_topic = mqtt_topic.get_direct_method_request_topic_for_subscribe() + payload = {"json": "derived", "from": {"byte": "payload"}} + # MQTTMessage 1 + mreq_name1 = "some_method" + mreq_topic1 = generic_topic.rstrip("#") + 
"{}/?$rid={}".format(mreq_name1, 1) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=mreq_topic1.encode("utf-8")) + mqtt_msg1.payload = json.dumps(payload).encode("utf-8") + # MQTTMessage 2 + mreq_name2 = "some_other_method" + mreq_topic2 = generic_topic.rstrip("#") + "{}/?$rid={}".format(mreq_name2, 2) + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=mreq_topic2.encode("utf-8")) + mqtt_msg2.payload = json.dumps(payload).encode("utf-8") + + # Inject failure into the first extraction only + original_fn = mqtt_topic.extract_name_from_direct_method_request_topic + mock_extract = mocker.patch.object( + mqtt_topic, "extract_name_from_direct_method_request_topic" + ) + + def fail_once(*args, **kwargs): + mock_extract.side_effect = original_fn + raise arbitrary_exception + + mock_extract.side_effect = fail_once + + # Load the MQTTMessages + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + + # The DirectMethodResponse is derived from the second MQTTMessage instead of the first, + # because the first failed, the error was suppressed, and the MQTTMessage discarded + mreq = await client.incoming_direct_method_requests.__anext__() + assert mreq.name == mreq_name2 + assert mreq_name2 != mreq_name1 + assert mock_extract.call_count == 2 + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while decoding the payload from the MQTTMessage, dropping the MQTTMessage and continuing" + ) + async def test_payload_decode_fails(self, mocker, client, arbitrary_exception): + # Create two messages + generic_topic = mqtt_topic.get_direct_method_request_topic_for_subscribe() + # MQTTMessage 1 (No payload due to mock below) + mreq_name1 = "some_method" + mreq_topic1 = generic_topic.rstrip("#") + "{}/?$rid={}".format(mreq_name1, 1) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=mreq_topic1.encode("utf-8")) + # MQTTMessage 2 + mreq_name2 = "some_other_method" + mreq_topic2 = 
generic_topic.rstrip("#") + "{}/?$rid={}".format(mreq_name2, 2) + payload2 = {"json": "derived", "from": {"byte": "payload"}} + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=mreq_topic2.encode("utf-8")) + mqtt_msg2.payload = json.dumps(payload2).encode("utf-8") + + # Inject failure to the first MQTTMessage only + mqtt_msg1.payload = mocker.MagicMock() + mqtt_msg1.payload.decode.side_effect = arbitrary_exception + + # Load the MQTTMessages + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + + # The DirectMethodResponse is derived from the second MQTTMessage instead of the first, + # because the first failed, the error was suppressed, and the MQTTMessage discarded + mreq = await client.incoming_direct_method_requests.__anext__() + assert mreq.name == mreq_name2 + assert mreq_name2 != mreq_name1 + assert mqtt_msg1.payload.decode.call_count == 1 + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while converting the payload from the MQTTMessage to JSON, dropping the MQTTMessage and continuing" + ) + async def test_json_loads_fails(self, mocker, client, arbitrary_exception): + # Create two messages + generic_topic = mqtt_topic.get_direct_method_request_topic_for_subscribe() + payload = {"json": "derived", "from": {"byte": "payload"}} + # MQTTMessage 1 + mreq_name1 = "some_method" + mreq_topic1 = generic_topic.rstrip("#") + "{}/?$rid={}".format(mreq_name1, 1) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=mreq_topic1.encode("utf-8")) + mqtt_msg1.payload = json.dumps(payload).encode("utf-8") + # MQTTMessage 2 + mreq_name2 = "some_other_method" + mreq_topic2 = generic_topic.rstrip("#") + "{}/?$rid={}".format(mreq_name2, 2) + mqtt_msg2 = mqtt.MQTTMessage(mid=2, topic=mreq_topic2.encode("utf-8")) + mqtt_msg2.payload = json.dumps(payload).encode("utf-8") + + # Inject failure to the first json conversion only + original_loads = json.loads + 
mock_loads = mocker.patch.object(json, "loads") + + def fail_once(*args, **kwargs): + mock_loads.side_effect = original_loads + raise arbitrary_exception + + mock_loads.side_effect = fail_once + + # Load the MQTTMessages + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + + # The DirectMethodResponse is derived from the second MQTTMessage instead of the first, + # because the first failed, the error was suppressed, and the MQTTMessage discarded + mreq = await client.incoming_direct_method_requests.__anext__() + assert mreq.name == mreq_name2 + assert mreq_name2 != mreq_name1 + assert mock_loads.call_count == 2 + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while instantiating the DirectMethodRequest object from the MQTTMessage values, dropping the MQTTMessage and continuing" + ) + async def test_request_instantiation_fails(self, mocker, client, arbitrary_exception): + # Create two messages + generic_topic = mqtt_topic.get_direct_method_request_topic_for_subscribe() + payload = {"json": "derived", "from": {"byte": "payload"}} + # MQTTMessage 1 + mreq_name1 = "some_method" + mreq_topic1 = generic_topic.rstrip("#") + "{}/?$rid={}".format(mreq_name1, 1) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=mreq_topic1.encode("utf-8")) + mqtt_msg1.payload = json.dumps(payload).encode("utf-8") + # MQTTMessage 2 + mreq_name2 = "some_other_method" + mreq_topic2 = generic_topic.rstrip("#") + "{}/?$rid={}".format(mreq_name2, 2) + mqtt_msg2 = mqtt.MQTTMessage(mid=2, topic=mreq_topic2.encode("utf-8")) + mqtt_msg2.payload = json.dumps(payload).encode("utf-8") + + # Inject failure into the first instantiation only + original_cls = models.DirectMethodRequest + mock_cls = mocker.patch.object(models, "DirectMethodRequest") + + def fail_once(*args, **kwargs): + mock_cls.side_effect = original_cls + raise arbitrary_exception + + mock_cls.side_effect = 
fail_once + + # Load the MQTTMessages + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + + # The DirectMethodResponse is derived from the second MQTTMessage instead of the first, + # because the first failed, the error was suppressed, and the MQTTMessage discarded + mreq = await client.incoming_direct_method_requests.__anext__() + assert mreq.name == mreq_name2 + assert mreq_name2 != mreq_name1 + assert mock_cls.call_count == 2 + + @pytest.mark.skip(reason="Currently can't figure out how to mock a generator correctly") + @pytest.mark.it("Can be cancelled while waiting for an MQTTMessage to arrive") + async def test_cancelled_while_waiting_for_message(self): + pass + + +@pytest.mark.describe("IoTHubMQTTClient - PROPERTY: .incoming_twin_patches") +class TestIoTHubMQTTClientIncomingTwinPatches: + @pytest.mark.it("Is an AsyncGenerator maintained as a read-only property") + def test_property(self, client): + assert isinstance(client.incoming_twin_patches, typing.AsyncGenerator) + with pytest.raises(AttributeError): + client.incoming_twin_patches = 12 + + @pytest.mark.it( + "Yields a JSON-formatted dictionary whenever the MQTTClient receives an MQTTMessage on the incoming twin patch topic" + ) + async def test_yields_twin(self, client): + generic_topic = mqtt_topic.get_twin_patch_topic_for_subscribe() + patch1_topic = generic_topic.rstrip("#") + "?$version=1" + patch2_topic = generic_topic.rstrip("#") + "?$version=2" + # Create MQTTMessages + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=patch1_topic.encode("utf-8")) + mqtt_msg1.payload = '{"property1": "value1", "property2": "value2", "$version": 1}'.encode( + "utf-8" + ) + mqtt_msg2 = mqtt.MQTTMessage(mid=2, topic=patch2_topic.encode("utf-8")) + mqtt_msg2.payload = '{"property1": "value3", "property2": "value4", "$version": 2}'.encode( + "utf-8" + ) + # Load the MQTTMessages into the MQTTClient's 
filtered message queue + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + # Get items form generator + patch1 = await client.incoming_twin_patches.__anext__() + assert isinstance(patch1, dict) + assert json.dumps(patch1) # This would fail if it's not valid JSON + patch2 = await client.incoming_twin_patches.__anext__() + assert isinstance(patch2, dict) + assert json.dumps(patch1) # This would fail if it's not valid JSON + + @pytest.mark.it( + "Derives the yielded JSON-formatted dictionary from the MQTTMessage's byte payload using the utf-8 codec" + ) + async def test_payload(self, client): + generic_topic = mqtt_topic.get_twin_patch_topic_for_subscribe() + patch_topic = generic_topic.rstrip("#") + "?$version=1" + expected_json = {"property1": "value1", "property2": "value2", "$version": 1} + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=patch_topic.encode("utf-8")) + mqtt_msg.payload = json.dumps(expected_json).encode("utf-8") + + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg) + patch = await client.incoming_twin_patches.__anext__() + assert patch == expected_json + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while decoding the payload from the MQTTMessage, dropping the MQTTMessage and continuing" + ) + async def test_payload_decode_fails(self, mocker, client, arbitrary_exception): + # Create two messages + generic_topic = mqtt_topic.get_twin_patch_topic_for_subscribe() + # MQTTMessage 1 (no payload due to mock below) + patch_topic1 = generic_topic.rstrip("#") + "?$version=1" + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=patch_topic1.encode("utf-8")) + # MQTTMessage 2 + patch_topic2 = generic_topic.rstrip("#") + "?$version=2" + payload2 = {"property1": "value1", "property2": "value2", "$version": 2} + mqtt_msg2 = mqtt.MQTTMessage(mid=2, topic=patch_topic2.encode("utf-8")) + mqtt_msg2.payload = 
json.dumps(payload2).encode("utf-8") + + # Inject failure to the first MQTTMessage only + mqtt_msg1.payload = mocker.MagicMock() + mqtt_msg1.payload.decode.side_effect = arbitrary_exception + + # Load the MQTTMessages + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + + # The twin patch is derived from the second message instead of the first, because the first + # failed, the error was suppressed, and the message discarded + patch = await client.incoming_twin_patches.__anext__() + assert patch == payload2 + assert mqtt_msg1.payload.decode.call_count == 1 + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while converting the payload from the MQTTMessage to JSON, dropping the MQTTMessage and continuing" + ) + async def test_json_loads_fails(self, mocker, client, arbitrary_exception): + # Create two messages + generic_topic = mqtt_topic.get_twin_patch_topic_for_subscribe() + # MQTTMessage 1 + patch_topic1 = generic_topic.rstrip("#") + "?$version=1" + payload1 = {"property1": "value1", "property2": "value2", "$version": 1} + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=patch_topic1.encode("utf-8")) + mqtt_msg1.payload = json.dumps(payload1).encode("utf-8") + # MQTTMessage 2 + patch_topic2 = generic_topic.rstrip("#") + "?$version=2" + payload2 = {"property1": "value1", "property2": "value2", "$version": 2} + mqtt_msg2 = mqtt.MQTTMessage(mid=2, topic=patch_topic2.encode("utf-8")) + mqtt_msg2.payload = json.dumps(payload2).encode("utf-8") + + # Inject failure to the first json conversion only + original_loads = json.loads + mock_loads = mocker.patch.object(json, "loads") + + def fail_once(*args, **kwargs): + mock_loads.side_effect = original_loads + raise arbitrary_exception + + mock_loads.side_effect = fail_once + + # Load the MQTTMessages + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await 
client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + + # The twin patch is derived from the second message instead of the first, because the first + # failed, the error was suppressed, and the message discarded + patch = await client.incoming_twin_patches.__anext__() + assert patch == payload2 + assert payload1 != payload2 + assert mock_loads.call_count == 2 + + @pytest.mark.skip(reason="Currently can't figure out how to mock a generator correctly") + @pytest.mark.it("Can be cancelled while waiting for an MQTTMessage to arrive") + async def test_cancelled_while_waiting_for_message(self): + pass + + +@pytest.mark.describe("IoTHubMQTTClient - PROPERTY: .connected") +class TestIoTHubMQTTClientConnected: + @pytest.mark.it("Returns the result of the MQTTClient's .is_connected() method") + def test_returns_result(self, mocker, client): + assert client._mqtt_client.is_connected.call_count == 0 + + result = client.connected + + assert client._mqtt_client.is_connected.call_count == 1 + assert client._mqtt_client.is_connected.call_args == mocker.call() + assert result is client._mqtt_client.is_connected.return_value + + +@pytest.mark.describe("IoTHubMQTTClient - BG TASK: ._process_twin_responses") +class TestIoTHubMQTTClientProcessTwinResponses: + response_payloads = [ + pytest.param('{"json": "in", "a": {"string": "format"}}', id="Get Twin Response"), + pytest.param(" ", id="Twin Patch Response"), + ] + + @pytest.mark.it( + "Creates a Response containing the request id and status code from the topic, as well as the utf-8 decoded payload of the MQTTMessage, when the MQTTClient receives an MQTTMessage on the twin response topic" + ) + @pytest.mark.parametrize( + "status", + [ + pytest.param(200, id="Status Code: 200"), + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status Code: 400"), + pytest.param(500, id="Status Code: 500"), + ], + ) + @pytest.mark.parametrize("payload_str", response_payloads) + async def test_response(self, 
mocker, client, status, payload_str): + # Mocks + mocker.patch.object(client, "_request_ledger", spec=rr.RequestLedger) + spy_response_factory = mocker.spy(rr, "Response") + # Set up MQTTMessages + generic_topic = mqtt_topic.get_twin_response_topic_for_subscribe() + rid = "some rid" + msg_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(status, rid) + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=msg_topic.encode("utf-8")) + mqtt_msg.payload = payload_str.encode("utf-8") + + # Start task + t = asyncio.create_task(client._process_twin_responses()) + await asyncio.sleep(0.1) + + # No Responses have been created yet + assert spy_response_factory.call_count == 0 + + # Load the MQTTMessage into the MQTTClient's filtered message queue + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg) + await asyncio.sleep(0.1) + + # Response was created + assert spy_response_factory.call_count == 1 + resp1 = spy_response_factory.spy_return + assert resp1.request_id == rid + assert resp1.status == status + assert resp1.body == payload_str + + t.cancel() + + @pytest.mark.it("Matches the newly created Response on the RequestLedger") + @pytest.mark.parametrize("payload_str", response_payloads) + async def test_match(self, mocker, client, payload_str): + # Mock + mock_ledger = mocker.patch.object(client, "_request_ledger", spec=rr.RequestLedger) + spy_response_factory = mocker.spy(rr, "Response") + # Set up MQTTMessage + generic_topic = mqtt_topic.get_twin_response_topic_for_subscribe() + topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, "some rid") + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=topic.encode("utf-8")) + mqtt_msg.payload = payload_str.encode("utf-8") + + # No Responses have been created yet + assert spy_response_factory.call_count == 0 + assert mock_ledger.match_response.call_count == 0 + + # Start task + t = asyncio.create_task(client._process_twin_responses()) + await asyncio.sleep(0.1) + + # Load the MQTTMessage into the 
MQTTClient's filtered message queue + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg) + await asyncio.sleep(0.1) + + # Response was created + assert spy_response_factory.call_count == 1 + resp1 = spy_response_factory.spy_return + assert mock_ledger.match_response.call_count == 1 + assert mock_ledger.match_response.call_args == mocker.call(resp1) + + t.cancel() + + @pytest.mark.it("Indefinitely repeats") + async def test_repeat(self, mocker, client): + mock_ledger = mocker.patch.object(client, "_request_ledger", spec=rr.RequestLedger) + spy_response_factory = mocker.spy(rr, "Response") + assert spy_response_factory.call_count == 0 + assert mock_ledger.match_response.call_count == 0 + + generic_topic = mqtt_topic.get_twin_response_topic_for_subscribe() + topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, "some rid") + + t = asyncio.create_task(client._process_twin_responses()) + await asyncio.sleep(0.1) + + # Test that behavior repeats up to 10 times. 
No way to really prove infinite + i = 0 + assert mock_ledger.match_response.call_count == 0 + while i < 10: + i += 1 + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=topic.encode("utf-8")) + # Switch between Get Twin and Twin Patch responses + if i % 2 == 0: + mqtt_msg.payload = " ".encode("utf-8") + else: + mqtt_msg.payload = '{"json": "in", "a": {"string": "format"}}'.encode("utf-8") + # Load the MQTTMessage into the MQTTClient's filtered message queue + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg) + await asyncio.sleep(0.1) + # Response was created + assert spy_response_factory.call_count == i + + assert not t.done() + t.cancel() + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while extracting the request id from the MQTTMessage, dropping the MQTTMessage and continuing" + ) + async def test_request_id_extraction_fails(self, mocker, client, arbitrary_exception): + # Inject failure + original_fn = mqtt_topic.extract_request_id_from_twin_response_topic + mocker.patch.object( + mqtt_topic, + "extract_request_id_from_twin_response_topic", + side_effect=arbitrary_exception, + ) + + # Create two messages that are the same other than the request id + generic_topic = mqtt_topic.get_twin_response_topic_for_subscribe() + # Response #1 + rid1 = "rid1" + msg1_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid1) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=msg1_topic.encode("utf-8")) + mqtt_msg1.payload = " ".encode("utf-8") + # Response #2 + rid2 = "rid2" + msg2_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid2) + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=msg2_topic.encode("utf-8")) + mqtt_msg2.payload = " ".encode("utf-8") + + # Spy on the Response object + spy_response_factory = mocker.spy(rr, "Response") + + # Start task + t = asyncio.create_task(client._process_twin_responses()) + + # Load the first MQTTMessage + await 
client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await asyncio.sleep(0.1) + + # No Response was created due to the injected failure (but failure was suppressed) + assert spy_response_factory.call_count == 0 + mqtt_topic.extract_request_id_from_twin_response_topic.call_count == 1 + + # Un-inject the failure + mqtt_topic.extract_request_id_from_twin_response_topic = original_fn + + # Load the second MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + await asyncio.sleep(0.1) + + # This time a Response was created, demonstrating that the task is still functional + assert spy_response_factory.call_count == 1 + resp2 = spy_response_factory.spy_return + assert resp2.request_id == rid2 + assert resp2.status == 200 + assert resp2.body == mqtt_msg2.payload.decode("utf-8") + + assert not t.done() + + t.cancel() + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while extracting the status code from the MQTTMessage, dropping the MQTTMessage and continuing" + ) + async def test_status_code_extraction_fails(self, mocker, client, arbitrary_exception): + # Inject failure + original_fn = mqtt_topic.extract_status_code_from_twin_response_topic + mocker.patch.object( + mqtt_topic, + "extract_status_code_from_twin_response_topic", + side_effect=arbitrary_exception, + ) + + # Create two messages that are the same other than the request id + generic_topic = mqtt_topic.get_twin_response_topic_for_subscribe() + # Response #1 + rid1 = "rid1" + msg1_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid1) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=msg1_topic.encode("utf-8")) + mqtt_msg1.payload = " ".encode("utf-8") + # Response #2 + rid2 = "rid2" + msg2_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid2) + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=msg2_topic.encode("utf-8")) + mqtt_msg2.payload = " ".encode("utf-8") + + # Spy on the Response object + 
spy_response_factory = mocker.spy(rr, "Response") + + # Start task + t = asyncio.create_task(client._process_twin_responses()) + + # Load the first MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await asyncio.sleep(0.1) + + # No Response was created due to the injected failure (but failure was suppressed) + assert spy_response_factory.call_count == 0 + mqtt_topic.extract_status_code_from_twin_response_topic.call_count == 1 + + # Un-inject the failure + mqtt_topic.extract_status_code_from_twin_response_topic = original_fn + + # Load the second MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + await asyncio.sleep(0.1) + + # This time a Response was created, demonstrating that the previous failure did not + # crash the task + assert spy_response_factory.call_count == 1 + resp2 = spy_response_factory.spy_return + assert resp2.request_id == rid2 + assert resp2.status == 200 + assert resp2.body == mqtt_msg2.payload.decode("utf-8") + + assert not t.done() + + t.cancel() + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while decoding the payload from the MQTTMessage, dropping the MQTTMessage and continuing" + ) + async def test_payload_decode_fails(self, mocker, client, arbitrary_exception): + # Create two messages that are the same other than the request id + generic_topic = mqtt_topic.get_twin_response_topic_for_subscribe() + # Response #1 + rid1 = "rid1" + msg1_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid1) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=msg1_topic.encode("utf-8")) + mqtt_msg1.payload = " ".encode("utf-8") + # Response #2 + rid2 = "rid2" + msg2_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid2) + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=msg2_topic.encode("utf-8")) + mqtt_msg2.payload = " ".encode("utf-8") + + # Spy on the Response object + spy_response_factory = mocker.spy(rr, "Response") + + # 
Inject failure into the first MQTTMessage's payload + mqtt_msg1.payload = mocker.MagicMock() + mqtt_msg1.payload.decode.side_effect = arbitrary_exception + + # Start task + t = asyncio.create_task(client._process_twin_responses()) + + # Load the first MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await asyncio.sleep(0.1) + + # No Response was created due to the injected failure (but failure was suppressed) + assert spy_response_factory.call_count == 0 + assert mqtt_msg1.payload.decode.call_count == 1 + + # Load the second MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + await asyncio.sleep(0.1) + + # This time a Response was created, demonstrating that the previous failure did not + # crash the task + assert spy_response_factory.call_count == 1 + resp2 = spy_response_factory.spy_return + assert resp2.request_id == rid2 + assert resp2.status == 200 + assert resp2.body == mqtt_msg2.payload.decode("utf-8") + + assert not t.done() + + t.cancel() + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised instantiating the Response object from the MQTTMessage values, dropping the MQTTMessage and continuing" + ) + async def test_response_instantiation_fails(self, mocker, client, arbitrary_exception): + # Inject failure + original_cls = rr.Response + mocker.patch.object(rr, "Response", side_effect=arbitrary_exception) + + # Create two messages that are the same other than the request id + generic_topic = mqtt_topic.get_twin_response_topic_for_subscribe() + # Response #1 + rid1 = "rid1" + msg1_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid1) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=msg1_topic.encode("utf-8")) + mqtt_msg1.payload = " ".encode("utf-8") + # Response #2 + rid2 = "rid2" + msg2_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid2) + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=msg2_topic.encode("utf-8")) + 
mqtt_msg2.payload = " ".encode("utf-8") + + # Mock the ledger so we can see if it is used + mock_ledger = mocker.patch.object(client, "_request_ledger", spec=rr.RequestLedger) + + # Start task + t = asyncio.create_task(client._process_twin_responses()) + + # Load the first MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await asyncio.sleep(0.1) + + # No Response matched to the injected failure (but failure was suppressed) + assert mock_ledger.match_response.call_count == 0 + assert rr.Response.call_count == 1 + + # Un-inject the failure + rr.Response = original_cls + + # Load the second MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + await asyncio.sleep(0.1) + + # This time a Response was created and matched, demonstrating that the previous + # failure did not crash the task + assert mock_ledger.match_response.call_count == 1 + resp = mock_ledger.match_response.call_args[0][0] + assert resp.request_id == rid2 + assert resp.status == 200 + assert resp.body == mqtt_msg2.payload.decode("utf-8") + + assert not t.done() + + t.cancel() + + @pytest.mark.it( + "Suppresses any exceptions raised while matching the Response, dropping the MQTTMessage and continuing" + ) + @pytest.mark.parametrize( + "exception", + [ + pytest.param(KeyError(), id="KeyError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_response_match_fails(self, mocker, client, exception): + mock_ledger = mocker.patch.object(client, "_request_ledger", spec=rr.RequestLedger) + + # Create two messages that are the same other than the request id + generic_topic = mqtt_topic.get_twin_response_topic_for_subscribe() + # Response #1 + rid1 = "rid1" + msg1_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid1) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=msg1_topic.encode("utf-8")) + mqtt_msg1.payload = " ".encode("utf-8") + # Response #2 + 
rid2 = "rid2" + msg2_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid2) + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=msg2_topic.encode("utf-8")) + mqtt_msg2.payload = " ".encode("utf-8") + + # Inject failure into the response match + mock_ledger.match_response.side_effect = exception + + # Start task + t = asyncio.create_task(client._process_twin_responses()) + + # Load the first MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await asyncio.sleep(0.1) + + # Attempt to match response occurred (and thus, failed, due to mock) + assert mock_ledger.match_response.call_count == 1 + + # Un-inject the failure + mock_ledger.match_response.side_effect = None + + # Load the second MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + await asyncio.sleep(0.1) + + # Another response match occurred, demonstrating that the previous failure did not + # crash the task + assert mock_ledger.match_response.call_count == 2 + resp2 = mock_ledger.match_response.call_args[0][0] + assert resp2.request_id == rid2 + assert resp2.status == 200 + assert resp2.body == mqtt_msg2.payload.decode("utf-8") + + assert not t.done() + + t.cancel() + + @pytest.mark.skip(reason="Currently can't figure out how to mock a generator correctly") + @pytest.mark.it("Can be cancelled while waiting for an MQTTMessage to arrive") + async def test_cancelled_while_waiting_for_message(self): + pass + + @pytest.mark.it("Can be cancelled while matching Response") + async def test_cancelled_while_matching_response(self, mocker, client): + mock_ledger = mocker.patch.object(client, "_request_ledger", spec=rr.RequestLedger) + mock_ledger.match_response = custom_mock.HangingAsyncMock() + + # Set up MQTTMessage + generic_topic = mqtt_topic.get_twin_response_topic_for_subscribe() + topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, "some rid") + mqtt_msg = mqtt.MQTTMessage(mid=1, 
topic=topic.encode("utf-8")) + mqtt_msg.payload = " ".encode("utf-8") + + # Start task + t = asyncio.create_task(client._process_twin_responses()) + await asyncio.sleep(0.1) + + # Load the MQTTMessage into the MQTTClient's filtered message queue + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg) + + # Matching response is hanging + await mock_ledger.match_response.wait_for_hang() + + # Task can be cancelled + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t diff --git a/tests/unit/test_iothub_session.py b/tests/unit/test_iothub_session.py new file mode 100644 index 000000000..01001d8c9 --- /dev/null +++ b/tests/unit/test_iothub_session.py @@ -0,0 +1,2507 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import asyncio +import pytest +import ssl +import time +import typing +from dev_utils import custom_mock +from pytest_lazyfixture import lazy_fixture +from azure.iot.device.iothub_session import IoTHubSession +from azure.iot.device import config, models +from azure.iot.device import connection_string as cs +from azure.iot.device import exceptions as exc +from azure.iot.device import iothub_mqtt_client as mqtt +from azure.iot.device import sastoken as st +from azure.iot.device import signing_mechanism as sm + +FAKE_DEVICE_ID = "fake_device_id" +FAKE_MODULE_ID = "fake_module_id" +FAKE_HOSTNAME = "fake.hostname" +FAKE_URI = "fake/resource/location" +FAKE_SHARED_ACCESS_KEY = "Zm9vYmFy" +FAKE_SIGNATURE = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=" + +# ~~~~~ Helpers ~~~~~~ + + +def sastoken_generator_fn(): + return "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( + resource=FAKE_URI, signature=FAKE_SIGNATURE, 
expiry=str(int(time.time()) + 3600) + ) + + +def get_expected_uri(hostname, device_id, module_id): + if module_id: + return "{hostname}/devices/{device_id}/modules/{module_id}".format( + hostname=hostname, device_id=device_id, module_id=module_id + ) + else: + return "{hostname}/devices/{device_id}".format(hostname=hostname, device_id=device_id) + + +# ~~~~~ Fixtures ~~~~~~ + +# Mock out the underlying client in order to not do network operations +@pytest.fixture(autouse=True) +def mock_mqtt_iothub_client(mocker): + mock_client = mocker.patch.object( + mqtt, "IoTHubMQTTClient", spec=mqtt.IoTHubMQTTClient + ).return_value + # Use a HangingAsyncMock here so that the coroutine does not return until we want it to + mock_client.wait_for_disconnect = custom_mock.HangingAsyncMock() + return mock_client + + +@pytest.fixture(autouse=True) +def mock_sastoken_provider(mocker): + return mocker.patch.object(st, "SasTokenProvider", spec=st.SasTokenProvider).return_value + + +@pytest.fixture +def custom_ssl_context(): + # NOTE: It doesn't matter how the SSLContext is configured for the tests that use this fixture, + # so it isn't configured at all. 
+ return ssl.SSLContext() + + +@pytest.fixture(params=["Default SSLContext", "Custom SSLContext"]) +def optional_ssl_context(request, custom_ssl_context): + """Sometimes tests need to show something works with or without an SSLContext""" + if request.param == "Custom SSLContext": + return custom_ssl_context + else: + return None + + +@pytest.fixture +async def session(custom_ssl_context): + """Use a device configuration and custom SSL auth for simplicity""" + async with IoTHubSession( + hostname=FAKE_HOSTNAME, device_id=FAKE_DEVICE_ID, ssl_context=custom_ssl_context + ) as session: + yield session + + +@pytest.fixture +def disconnected_session(custom_ssl_context): + return IoTHubSession( + hostname=FAKE_HOSTNAME, device_id=FAKE_DEVICE_ID, ssl_context=custom_ssl_context + ) + + +# ~~~~~ Parametrizations ~~~~~ +# Define parametrizations that will be used across multiple test suites, and that may eventually +# need to be changed everywhere, e.g. new auth scheme added. +# Note that some parametrizations are also defined within the scope of a single test suite if that +# is the only unit they are relevant to. + + +# Parameters for arguments to the __init__ or factory methods. Represent different types of +# authentication. Use this parametrization whenever possible on .create() tests. +# NOTE: Do NOT combine this with the SSL fixtures above. 
This parametrization contains +# ssl contexts where necessary +create_auth_params = [ + # Provide args in form 'shared_access_key, sastoken_fn, ssl_context' + pytest.param( + FAKE_SHARED_ACCESS_KEY, None, None, id="Shared Access Key SAS Auth + Default SSLContext" + ), + pytest.param( + FAKE_SHARED_ACCESS_KEY, + None, + lazy_fixture("custom_ssl_context"), + id="Shared Access Key SAS Auth + Custom SSLContext", + ), + pytest.param( + None, + sastoken_generator_fn, + None, + id="User-Provided SAS Token Auth + Default SSLContext", + ), + pytest.param( + None, + sastoken_generator_fn, + lazy_fixture("custom_ssl_context"), + id="User-Provided SAS Token Auth + Custom SSLContext", + ), + pytest.param(None, None, lazy_fixture("custom_ssl_context"), id="Custom SSLContext Auth"), +] +# Just the parameters where SAS auth is used +create_auth_params_sas = [param for param in create_auth_params if "SAS" in param.id] +# Just the parameters where a Shared Access Key auth is used +create_auth_params_sak = [param for param in create_auth_params if param.values[0] is not None] +# Just the parameters where SAS callback auth is used +create_auth_params_token_cb = [param for param in create_auth_params if param.values[1] is not None] +# Just the parameters where a custom SSLContext is provided +create_auth_params_custom_ssl = [ + param for param in create_auth_params if param.values[2] is not None +] +# Just the parameters where a custom SSLContext is NOT provided +create_auth_params_default_ssl = [param for param in create_auth_params if param.values[2] is None] + + +# Covers all option kwargs shared across client factory methods +factory_kwargs = [ + # pytest.param("auto_reconnect", False, id="auto_reconnect"), + pytest.param("keep_alive", 34, id="keep_alive"), + pytest.param("product_info", "fake-product-info", id="product_info"), + pytest.param( + "proxy_options", config.ProxyOptions("HTTP", "fake.address", 1080), id="proxy_options" + ), + pytest.param("websockets", True, 
id="websockets"), +] + +sk_sm_create_exceptions = [ + pytest.param(ValueError(), id="ValueError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), +] + +# Does the session exit gracefully or because of error? +graceful_exit_params = [ + pytest.param(True, id="graceful exit"), + pytest.param(False, id="exit because of exception"), +] + + +@pytest.mark.describe("IoTHubSession -- Instantiation") +class TestIoTHubSessionInstantiation: + create_id_params = [ + # Provide args in the form 'device_id, module_id' + pytest.param(FAKE_DEVICE_ID, None, id="Device"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module"), + ] + + @pytest.mark.it( + "Instantiates and stores a SasTokenProvider that uses symmetric key-based token generation, if `shared_access_key` is provided" + ) + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params_sak) + @pytest.mark.parametrize("device_id, module_id", create_id_params) + async def test_sak_auth( + self, mocker, device_id, module_id, shared_access_key, sastoken_fn, ssl_context + ): + assert sastoken_fn is None + spy_sk_sm_cls = mocker.spy(sm, "SymmetricKeySigningMechanism") + spy_st_generator_cls = mocker.spy(st, "InternalSasTokenGenerator") + spy_st_provider_cls = mocker.spy(st, "SasTokenProvider") + expected_uri = get_expected_uri(FAKE_HOSTNAME, device_id, module_id) + + session = IoTHubSession( + hostname=FAKE_HOSTNAME, + device_id=device_id, + module_id=module_id, + shared_access_key=shared_access_key, + ssl_context=ssl_context, + ) + + # SymmetricKeySigningMechanism was created from the shared access key + assert spy_sk_sm_cls.call_count == 1 + assert spy_sk_sm_cls.call_args == mocker.call(shared_access_key) + # InternalSasTokenGenerator was created from the SymmetricKeySigningMechanism + assert spy_st_generator_cls.call_count == 1 + assert spy_st_generator_cls.call_args == mocker.call( + signing_mechanism=spy_sk_sm_cls.spy_return, uri=expected_uri, ttl=3600 + ) + 
# SasTokenProvider was created from the InternalSasTokenGenerator + assert spy_st_provider_cls.call_count == 1 + assert spy_st_provider_cls.call_args == mocker.call(spy_st_generator_cls.spy_return) + # SasTokenProvider was set on the Session + assert session._sastoken_provider is spy_st_provider_cls.spy_return + + @pytest.mark.it( + "Instantiates and stores a SasTokenProvider that uses callback-based token generation, if `sastoken_fn` is provided" + ) + @pytest.mark.parametrize( + "shared_access_key, sastoken_fn, ssl_context", create_auth_params_token_cb + ) + @pytest.mark.parametrize("device_id, module_id", create_id_params) + async def test_token_callback_auth( + self, mocker, device_id, module_id, shared_access_key, sastoken_fn, ssl_context + ): + assert shared_access_key is None + spy_st_generator_cls = mocker.spy(st, "ExternalSasTokenGenerator") + spy_st_provider_cls = mocker.spy(st, "SasTokenProvider") + + session = IoTHubSession( + hostname=FAKE_HOSTNAME, + device_id=device_id, + module_id=module_id, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + # ExternalSasTokenGenerator was created from the sastoken_fn + assert spy_st_generator_cls.call_count == 1 + assert spy_st_generator_cls.call_args == mocker.call(sastoken_fn) + # SasTokenProvider was created from the ExternalSasTokenGenerator + assert spy_st_provider_cls.call_count == 1 + assert spy_st_provider_cls.call_args == mocker.call(spy_st_generator_cls.spy_return) + # SasTokenProvider was set on the Session + assert session._sastoken_provider is spy_st_provider_cls.spy_return + + @pytest.mark.it( + "Does not instantiate or store any SasTokenProvider if neither `shared_access_key` nor `sastoken_fn` are provided" + ) + @pytest.mark.parametrize("device_id, module_id", create_id_params) + async def test_non_sas_auth(self, mocker, device_id, module_id, custom_ssl_context): + spy_st_provider_cls = mocker.spy(st, "SasTokenProvider") + + session = IoTHubSession( + hostname=FAKE_HOSTNAME, + 
device_id=device_id, + module_id=module_id, + ssl_context=custom_ssl_context, + ) + + # No SasTokenProvider + assert session._sastoken_provider is None + assert spy_st_provider_cls.call_count == 0 + + @pytest.mark.it( + "Instantiates and stores an IoTHubMQTTClient, using a new IoTHubClientConfig object" + ) + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params) + @pytest.mark.parametrize("device_id, module_id", create_id_params) + async def test_mqtt_client( + self, mocker, device_id, module_id, shared_access_key, sastoken_fn, ssl_context + ): + spy_config_cls = mocker.spy(config, "IoTHubClientConfig") + spy_mqtt_cls = mocker.spy(mqtt, "IoTHubMQTTClient") + assert spy_config_cls.call_count == 0 + assert spy_mqtt_cls.call_count == 0 + + session = IoTHubSession( + hostname=FAKE_HOSTNAME, + device_id=device_id, + module_id=module_id, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + assert spy_config_cls.call_count == 1 + assert spy_mqtt_cls.call_count == 1 + assert spy_mqtt_cls.call_args == mocker.call(spy_config_cls.spy_return) + assert session._mqtt_client is spy_mqtt_cls.spy_return + + @pytest.mark.it( + "Sets the provided `hostname` on the IoTHubClientConfig used to create the IoTHubMQTTClient" + ) + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params) + @pytest.mark.parametrize("device_id, module_id", create_id_params) + async def test_hostname( + self, mocker, device_id, module_id, shared_access_key, sastoken_fn, ssl_context + ): + spy_mqtt_cls = mocker.spy(mqtt, "IoTHubMQTTClient") + + IoTHubSession( + hostname=FAKE_HOSTNAME, + device_id=device_id, + module_id=module_id, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + cfg = spy_mqtt_cls.call_args[0][0] + assert cfg.hostname == FAKE_HOSTNAME + + @pytest.mark.it( + "Sets the provided `device_id` and `module_id` values on the 
IoTHubClientConfig used to create the IoTHubMQTTClient" + ) + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params) + @pytest.mark.parametrize("device_id, module_id", create_id_params) + async def test_ids( + self, mocker, device_id, module_id, shared_access_key, sastoken_fn, ssl_context + ): + spy_mqtt_cls = mocker.spy(mqtt, "IoTHubMQTTClient") + + IoTHubSession( + hostname=FAKE_HOSTNAME, + device_id=device_id, + module_id=module_id, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + cfg = spy_mqtt_cls.call_args[0][0] + assert cfg.device_id == device_id + assert cfg.module_id == module_id + + @pytest.mark.it( + "Sets the provided `ssl_context` on the IoTHubClientConfig used to create the IoTHubMQTTClient, if provided" + ) + @pytest.mark.parametrize( + "shared_access_key, sastoken_fn, ssl_context", create_auth_params_custom_ssl + ) + @pytest.mark.parametrize("device_id, module_id", create_id_params) + async def test_custom_ssl_context( + self, mocker, device_id, module_id, shared_access_key, sastoken_fn, ssl_context + ): + spy_mqtt_cls = mocker.spy(mqtt, "IoTHubMQTTClient") + + IoTHubSession( + hostname=FAKE_HOSTNAME, + device_id=device_id, + module_id=module_id, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + cfg = spy_mqtt_cls.call_args[0][0] + assert cfg.ssl_context is ssl_context + + @pytest.mark.it( + "Sets a default SSLContext on the IoTHubClientConfig used to create the IoTHubMQTTClient, if `ssl_context` is not provided" + ) + @pytest.mark.parametrize( + "shared_access_key, sastoken_fn, ssl_context", create_auth_params_default_ssl + ) + @pytest.mark.parametrize("device_id, module_id", create_id_params) + async def test_default_ssl_context( + self, mocker, device_id, module_id, shared_access_key, sastoken_fn, ssl_context + ): + assert ssl_context is None + spy_mqtt_cls = mocker.spy(mqtt, "IoTHubMQTTClient") + 
my_ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) + original_ssl_ctx_cls = ssl.SSLContext + + # NOTE: SSLContext is difficult to mock as an entire class, due to how it implements + # instantiation. Essentially, if you mock the entire class, it will not be able to + # instantiate due to an internal reference to the class type, which of course has now been + # changed to MagicMock. To get around this, we mock the class with a side effect that can + # check the arguments passed to the constructor, return a pre-existing SSLContext, and then + # unset the mock to prevent future issues. + def return_and_reset(*args, **kwargs): + ssl.SSLContext = original_ssl_ctx_cls + assert kwargs["protocol"] is ssl.PROTOCOL_TLS_CLIENT + return my_ssl_context + + mocker.patch.object(ssl, "SSLContext", side_effect=return_and_reset) + mocker.spy(my_ssl_context, "load_default_certs") + + IoTHubSession( + hostname=FAKE_HOSTNAME, + device_id=device_id, + module_id=module_id, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ) + + cfg = spy_mqtt_cls.call_args[0][0] + ctx = cfg.ssl_context + assert ctx is my_ssl_context + # NOTE: ctx protocol is checked in the `return_and_reset` side effect above + assert ctx.verify_mode == ssl.CERT_REQUIRED + assert ctx.check_hostname is True + assert ctx.load_default_certs.call_count == 1 + assert ctx.load_default_certs.call_args == mocker.call() + assert ctx.minimum_version == ssl.TLSVersion.TLSv1_2 + + @pytest.mark.it( + "Sets the stored SasTokenProvider (if any) on the IoTHubClientConfig used to create the IoTHubMQTTClient" + ) + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params) + @pytest.mark.parametrize("device_id, module_id", create_id_params) + async def test_sastoken_provider_cfg( + self, mocker, device_id, module_id, shared_access_key, sastoken_fn, ssl_context + ): + spy_mqtt_cls = mocker.spy(mqtt, "IoTHubMQTTClient") + + session = IoTHubSession( + hostname=FAKE_HOSTNAME, + 
device_id=device_id, + module_id=module_id, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + cfg = spy_mqtt_cls.call_args[0][0] + assert cfg.sastoken_provider is session._sastoken_provider + + @pytest.mark.it( + "Sets `auto_reconnect` to False on the IoTHubClientConfig used to create the IoTHubMQTTClient" + ) + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params) + @pytest.mark.parametrize("device_id, module_id", create_id_params) + async def test_auto_reconnect_cfg( + self, mocker, device_id, module_id, shared_access_key, sastoken_fn, ssl_context + ): + spy_mqtt_cls = mocker.spy(mqtt, "IoTHubMQTTClient") + + IoTHubSession( + hostname=FAKE_HOSTNAME, + device_id=device_id, + module_id=module_id, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + cfg = spy_mqtt_cls.call_args[0][0] + assert cfg.auto_reconnect is False + + @pytest.mark.it( + "Sets any provided optional keyword arguments on the IoTHubClientConfig used to create the IoTHubMQTTClient" + ) + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params) + @pytest.mark.parametrize("kwarg_name, kwarg_value", factory_kwargs) + @pytest.mark.parametrize("device_id, module_id", create_id_params) + async def test_kwargs( + self, + mocker, + device_id, + module_id, + shared_access_key, + sastoken_fn, + ssl_context, + kwarg_name, + kwarg_value, + ): + spy_mqtt_cls = mocker.spy(mqtt, "IoTHubMQTTClient") + kwargs = {kwarg_name: kwarg_value} + + IoTHubSession( + hostname=FAKE_HOSTNAME, + device_id=device_id, + module_id=module_id, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + **kwargs + ) + + cfg = spy_mqtt_cls.call_args[0][0] + assert getattr(cfg, kwarg_name) == kwarg_value + + @pytest.mark.it("Sets the `wait_for_disconnect_task` attribute to None") + @pytest.mark.parametrize("shared_access_key, 
sastoken_fn, ssl_context", create_auth_params) + @pytest.mark.parametrize("device_id, module_id", create_id_params) + async def test_wait_for_disconnect_task( + self, device_id, module_id, shared_access_key, sastoken_fn, ssl_context + ): + session = IoTHubSession( + hostname=FAKE_HOSTNAME, + device_id=device_id, + module_id=module_id, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + assert session._wait_for_disconnect_task is None + + @pytest.mark.it( + "Raises ValueError if neither `shared_access_key`, `sastoken_fn` nor `ssl_context` are provided as parameters" + ) + @pytest.mark.parametrize("device_id, module_id", create_id_params) + async def test_no_auth(self, device_id, module_id): + with pytest.raises(ValueError): + IoTHubSession( + device_id=device_id, + module_id=module_id, + hostname=FAKE_HOSTNAME, + ) + + @pytest.mark.it( + "Raises ValueError if both `shared_access_key` and `sastoken_fn` are provided as parameters" + ) + @pytest.mark.parametrize("device_id, module_id", create_id_params) + async def test_conflicting_auth(self, device_id, module_id, optional_ssl_context): + with pytest.raises(ValueError): + IoTHubSession( + hostname=FAKE_HOSTNAME, + device_id=device_id, + module_id=module_id, + shared_access_key=FAKE_SHARED_ACCESS_KEY, + sastoken_fn=sastoken_generator_fn, + ssl_context=optional_ssl_context, + ) + + @pytest.mark.it("Raises TypeError if an invalid keyword argument is provided") + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params) + @pytest.mark.parametrize("device_id, module_id", create_id_params) + async def test_bad_kwarg( + self, device_id, module_id, shared_access_key, sastoken_fn, ssl_context + ): + with pytest.raises(TypeError): + IoTHubSession( + hostname=FAKE_HOSTNAME, + device_id=device_id, + module_id=module_id, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + invalid_argument="some value", + ) + + 
@pytest.mark.it( + "Allows any exceptions raised when creating a SymmetricKeySigningMechanism to propagate" + ) + @pytest.mark.parametrize("exception", sk_sm_create_exceptions) + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params_sak) + @pytest.mark.parametrize("device_id, module_id", create_id_params) + async def test_sksm_raises( + self, mocker, device_id, module_id, shared_access_key, sastoken_fn, ssl_context, exception + ): + mocker.patch.object(sm, "SymmetricKeySigningMechanism", side_effect=exception) + assert sastoken_fn is None + + with pytest.raises(type(exception)) as e_info: + IoTHubSession( + hostname=FAKE_HOSTNAME, + device_id=device_id, + module_id=module_id, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + assert e_info.value is exception + + +@pytest.mark.describe("IoTHubSession - .from_connection_string") +class TestIoTHubSessionFromConnectionString: + factory_params = [ + # TODO: once Edge support is decided upon, either re-enable, or remove the commented Edge parameters + # TODO: Do we want gateway hostname tests that are non-Edge? probably? 
+ pytest.param( + "HostName={hostname};DeviceId={device_id};SharedAccessKey={shared_access_key}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + shared_access_key=FAKE_SHARED_ACCESS_KEY, + ), + None, + id="Standard Device Connection String w/ SharedAccessKey + Default SSLContext", + ), + pytest.param( + "HostName={hostname};DeviceId={device_id};SharedAccessKey={shared_access_key}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + shared_access_key=FAKE_SHARED_ACCESS_KEY, + ), + lazy_fixture("custom_ssl_context"), + id="Standard Device Connection String w/ SharedAccessKey + Custom SSLContext", + ), + # pytest.param( + # "HostName={hostname};DeviceId={device_id};SharedAccessKey={shared_access_key};GatewayHostName={gateway_hostname}".format( + # hostname=FAKE_HOSTNAME, + # device_id=FAKE_DEVICE_ID, + # shared_access_key=FAKE_SHARED_ACCESS_KEY, + # gateway_hostname=FAKE_GATEWAY_HOSTNAME, + # ), + # None, + # id="Edge Device Connection String w/ SharedAccessKey + Default SSLContext", + # ), + # pytest.param( + # "HostName={hostname};DeviceId={device_id};SharedAccessKey={shared_access_key};GatewayHostName={gateway_hostname}".format( + # hostname=FAKE_HOSTNAME, + # device_id=FAKE_DEVICE_ID, + # shared_access_key=FAKE_SHARED_ACCESS_KEY, + # gateway_hostname=FAKE_GATEWAY_HOSTNAME, + # ), + # lazy_fixture("custom_ssl_context"), + # id="Edge Device Connection String w/ SharedAccessKey + Custom SSLContext", + # ), + # NOTE: X509 certs imply use of custom SSLContext + pytest.param( + "HostName={hostname};DeviceId={device_id};x509=true".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + ), + lazy_fixture("custom_ssl_context"), + id="Standard Device Connection String w/ X509", + ), + # pytest.param( + # "HostName={hostname};DeviceId={device_id};GatewayHostName={gateway_hostname};x509=true".format( + # hostname=FAKE_HOSTNAME, + # device_id=FAKE_DEVICE_ID, + # gateway_hostname=FAKE_GATEWAY_HOSTNAME, + # ), + # 
lazy_fixture("custom_ssl_context"), + # id="Edge Device Connection String w/ X509", + # ), + pytest.param( + "HostName={hostname};DeviceId={device_id};ModuleId={module_id};SharedAccessKey={shared_access_key}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + shared_access_key=FAKE_SHARED_ACCESS_KEY, + ), + None, + id="Standard Module Connection String w/ SharedAccessKey + Default SSLContext", + ), + pytest.param( + "HostName={hostname};DeviceId={device_id};ModuleId={module_id};SharedAccessKey={shared_access_key}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + shared_access_key=FAKE_SHARED_ACCESS_KEY, + ), + lazy_fixture("custom_ssl_context"), + id="Standard Module Connection String w/ SharedAccessKey + Custom SSLContext", + ), + # pytest.param( + # "HostName={hostname};DeviceId={device_id};ModuleId={module_id};SharedAccessKey={shared_access_key};GatewayHostName={gateway_hostname}".format( + # hostname=FAKE_HOSTNAME, + # device_id=FAKE_DEVICE_ID, + # module_id=FAKE_MODULE_ID, + # shared_access_key=FAKE_SHARED_ACCESS_KEY, + # gateway_hostname=FAKE_GATEWAY_HOSTNAME, + # ), + # None, + # id="Edge Module Connection String w/ SharedAccessKey + Default SSLContext", + # ), + # pytest.param( + # "HostName={hostname};DeviceId={device_id};ModuleId={module_id};SharedAccessKey={shared_access_key};GatewayHostName={gateway_hostname}".format( + # hostname=FAKE_HOSTNAME, + # device_id=FAKE_DEVICE_ID, + # module_id=FAKE_MODULE_ID, + # shared_access_key=FAKE_SHARED_ACCESS_KEY, + # gateway_hostname=FAKE_GATEWAY_HOSTNAME, + # ), + # lazy_fixture("custom_ssl_context"), + # id="Edge Module Connection String w/ SharedAccessKey + Custom SSLContext", + # ), + # NOTE: X509 certs imply use of custom SSLContext + pytest.param( + "HostName={hostname};DeviceId={device_id};ModuleId={module_id};x509=true".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + ), + 
lazy_fixture("custom_ssl_context"), + id="Standard Module Connection String w/ X509", + ), + # pytest.param( + # "HostName={hostname};DeviceId={device_id};ModuleId={module_id};GatewayHostName={gateway_hostname};x509=true".format( + # hostname=FAKE_HOSTNAME, + # device_id=FAKE_DEVICE_ID, + # module_id=FAKE_MODULE_ID, + # gateway_hostname=FAKE_GATEWAY_HOSTNAME, + # ), + # lazy_fixture("custom_ssl_context"), + # id="Edge Module Connection String w/ X509", + # ), + ] + # Just the parameters for using standard connection strings + factory_params_no_gateway = [ + param for param in factory_params if cs.GATEWAY_HOST_NAME not in param.values[0] + ] + # Just the parameters for using connection strings with a GatewayHostName + factory_params_gateway = [ + param for param in factory_params if cs.GATEWAY_HOST_NAME in param.values[0] + ] + # # Just the parameters where a custom SSLContext is provided + # factory_params_custom_ssl = [param for param in factory_params if param.values[1] is not None] + # # Just the parameters where a custom SSLContext is NOT provided + # factory_params_default_ssl = [param for param in factory_params if param.values[1] is None] + # # Just the parameters for using SharedAccessKeys + # factory_params_sak = [ + # param for param in factory_params if cs.SHARED_ACCESS_KEY in param.values[0] + # ] + # Just the parameters for NOT using SharedAccessKeys + factory_params_no_sak = [ + param for param in factory_params if cs.SHARED_ACCESS_KEY not in param.values[0] + ] + + @pytest.mark.it("Returns a new IoTHubSession instance") + @pytest.mark.parametrize("connection_string, ssl_context", factory_params) + async def test_instantiation(self, connection_string, ssl_context): + session = IoTHubSession.from_connection_string(connection_string, ssl_context) + assert isinstance(session, IoTHubSession) + + @pytest.mark.it( + "Extracts the `DeviceId`, `ModuleId` and `SharedAccessKey` values from the connection string (if present), passing them to the IoTHubSession 
initializer" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params) + async def test_extracts_values(self, mocker, connection_string, ssl_context): + spy_session_init = mocker.spy(IoTHubSession, "__init__") + cs_obj = cs.ConnectionString(connection_string) + + IoTHubSession.from_connection_string(connection_string, ssl_context) + + assert spy_session_init.call_count == 1 + assert spy_session_init.call_args[1]["device_id"] == cs_obj[cs.DEVICE_ID] + assert spy_session_init.call_args[1]["module_id"] == cs_obj.get(cs.MODULE_ID) + assert spy_session_init.call_args[1]["shared_access_key"] == cs_obj.get( + cs.SHARED_ACCESS_KEY + ) + + @pytest.mark.it( + "Extracts the `HostName` value from the connection string and passes it to the IoTHubSession initializer as the `hostname`, if no `GatewayHostName` value is present in the connection string" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_no_gateway) + async def test_hostname(self, mocker, connection_string, ssl_context): + spy_session_init = mocker.spy(IoTHubSession, "__init__") + cs_obj = cs.ConnectionString(connection_string) + + IoTHubSession.from_connection_string(connection_string, ssl_context) + + assert spy_session_init.call_count == 1 + assert spy_session_init.call_args[1]["hostname"] == cs_obj[cs.HOST_NAME] + + # TODO: This test is currently being skipped because test data does not include such values + # Get clarity on how we want to handle Edge, and then clear this up. 
+ @pytest.mark.it( + "Extracts the `GatewayHostName` value from the connection string and passes it to the IoTHubSession initializer as the `hostname`, if present in the connection string" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_gateway) + async def test_gateway_hostname(self, mocker, connection_string, ssl_context): + spy_session_init = mocker.spy(IoTHubSession, "__init__") + cs_obj = cs.ConnectionString(connection_string) + assert cs_obj[cs.GATEWAY_HOST_NAME] != cs_obj[cs.HOST_NAME] + + IoTHubSession.from_connection_string(connection_string, ssl_context) + + assert spy_session_init.call_count == 1 + assert spy_session_init.call_args[1]["hostname"] == cs_obj[cs.GATEWAY_HOST_NAME] + + @pytest.mark.it("Passes any provided `ssl_context` to the IoTHubSession initializer") + @pytest.mark.parametrize("connection_string, ssl_context", factory_params) + async def test_ssl_context(self, mocker, connection_string, ssl_context): + spy_session_init = mocker.spy(IoTHubSession, "__init__") + + IoTHubSession.from_connection_string(connection_string, ssl_context) + + assert spy_session_init.call_count == 1 + assert spy_session_init.call_args[1]["ssl_context"] is ssl_context + + @pytest.mark.it( + "Passes any provided optional keyword arguments to the IoTHubSession initializer" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params) + @pytest.mark.parametrize("kwarg_name, kwarg_value", factory_kwargs) + async def test_kwargs(self, mocker, connection_string, ssl_context, kwarg_name, kwarg_value): + spy_session_init = mocker.spy(IoTHubSession, "__init__") + kwargs = {kwarg_name: kwarg_value} + + IoTHubSession.from_connection_string(connection_string, ssl_context, **kwargs) + + assert spy_session_init.call_count == 1 + assert spy_session_init.call_args[1][kwarg_name] == kwarg_value + + @pytest.mark.it( + "Raises ValueError if `x509=true` is present in the connection string, but no `ssl_context` is provided" + ) + 
@pytest.mark.parametrize("connection_string, ssl_context", factory_params_no_sak) + async def test_x509_true_no_ssl(self, connection_string, ssl_context): + # Ignore the ssl_context provided by the parametrization + with pytest.raises(ValueError): + IoTHubSession.from_connection_string(connection_string) + + @pytest.mark.it( + "Does not raise a ValueError if `x509=false` is present in the connection string and no `ssl_context` is provided" + ) + async def test_x509_equals_false(self): + # NOTE: This is a weird test in that if you aren't using X509 certs, there shouldn't be + # an `x509` field in your connection string in the first place. But, semantically, it feels + # as though this test ought to exist to validate that we are checking the value of the + # field, not just the key name. + # NOTE: Because we're in the land of undefined behavior here, on account of this scenario + # not being supposed to happen, I'm arbitrarily deciding we're testing this with a string + # containing a SharedAccessKey and no GatewayHostName for simplicity. 
+ connection_string = "HostName={hostname};DeviceId={device_id};ModuleId={module_id};SharedAccessKey={shared_access_key};x509=false".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + shared_access_key=FAKE_SHARED_ACCESS_KEY, + ) + IoTHubSession.from_connection_string(connection_string) + # If the above invocation didn't raise, the test passed, no assertions required + + @pytest.mark.it("Raises TypeError if an invalid keyword argument is provided") + @pytest.mark.parametrize("connection_string, ssl_context", factory_params) + async def test_bad_kwarg(self, connection_string, ssl_context): + with pytest.raises(TypeError): + IoTHubSession.from_connection_string( + connection_string, ssl_context, invalid_argument="some_value" + ) + + @pytest.mark.it("Allows any exceptions raised while parsing the connection string to propagate") + @pytest.mark.parametrize( + "exception", + [ + pytest.param(ValueError(), id="ValueError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_cs_parsing_raises(self, mocker, optional_ssl_context, exception): + # NOTE: This test covers all invalid connection string scenarios. For more detail, see the + # dedicated connection string parsing tests for the `connection_string.py` module - there's + # no reason to replicate them all here. + # NOTE: For the purposes of this test, it does not matter what this connection string is. + # The one provided here is valid, but the mock will cause the parsing to raise anyway. 
+ connection_string = "HostName={hostname};DeviceId={device_id};ModuleId={module_id};SharedAccessKey={shared_access_key}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + shared_access_key=FAKE_SHARED_ACCESS_KEY, + ) + # Mock cs parsing + mocker.patch.object(cs, "ConnectionString", side_effect=exception) + + with pytest.raises(type(exception)) as e_info: + IoTHubSession.from_connection_string( + connection_string, ssl_context=optional_ssl_context + ) + assert e_info.value is exception + + @pytest.mark.it("Allows any exceptions raised by the initializer to propagate") + @pytest.mark.parametrize("connection_string, ssl_context", factory_params) + async def test_init_raises(self, mocker, connection_string, ssl_context, arbitrary_exception): + # NOTE: for an in-depth look at what possible exceptions could be raised, + # see the TestIoTHubSessionInstantiation suite. To prevent duplication, + # we will simply use an arbitrary exception here + mocker.patch.object(IoTHubSession, "__init__", side_effect=arbitrary_exception) + + with pytest.raises(type(arbitrary_exception)) as e_info: + IoTHubSession.from_connection_string(connection_string, ssl_context) + assert e_info.value is arbitrary_exception + + +@pytest.mark.describe("IoTHubSession -- Context Manager Usage") +class TestIoTHubSessionContextManager: + @pytest.fixture + def session(self, disconnected_session): + return disconnected_session + + @pytest.mark.it( + "Starts the IoTHubMQTTClient upon entry into the context manager, and stops it upon exit" + ) + async def test_mqtt_client_start_stop(self, session): + assert session._mqtt_client.start.await_count == 0 + assert session._mqtt_client.stop.await_count == 0 + + async with session as session: + assert session._mqtt_client.start.await_count == 1 + assert session._mqtt_client.stop.await_count == 0 + + assert session._mqtt_client.start.await_count == 1 + assert session._mqtt_client.stop.await_count == 1 + + @pytest.mark.it( + 
"Stops the IoTHubMQTTClient upon exit, even if an error was raised within the block inside the context manager" + ) + async def test_mqtt_client_start_stop_with_failure(self, session, arbitrary_exception): + assert session._mqtt_client.start.await_count == 0 + assert session._mqtt_client.stop.await_count == 0 + + try: + async with session as session: + assert session._mqtt_client.start.await_count == 1 + assert session._mqtt_client.stop.await_count == 0 + raise arbitrary_exception + except type(arbitrary_exception): + pass + + assert session._mqtt_client.start.await_count == 1 + assert session._mqtt_client.stop.await_count == 1 + + @pytest.mark.it( + "Connect the IoTHubMQTTClient upon entry into the context manager, and disconnect it upon exit" + ) + async def test_mqtt_client_connection(self, session): + assert session._mqtt_client.connect.await_count == 0 + assert session._mqtt_client.disconnect.await_count == 0 + + async with session as session: + assert session._mqtt_client.connect.await_count == 1 + assert session._mqtt_client.disconnect.await_count == 0 + + assert session._mqtt_client.connect.await_count == 1 + assert session._mqtt_client.disconnect.await_count == 1 + + @pytest.mark.it( + "Disconnect the IoTHubMQTTClient upon exit, even if an error was raised within the block inside the context manager" + ) + async def test_mqtt_client_connection_with_failure(self, session, arbitrary_exception): + assert session._mqtt_client.connect.await_count == 0 + assert session._mqtt_client.disconnect.await_count == 0 + + try: + async with session as session: + assert session._mqtt_client.connect.await_count == 1 + assert session._mqtt_client.disconnect.await_count == 0 + raise arbitrary_exception + except type(arbitrary_exception): + pass + + assert session._mqtt_client.connect.await_count == 1 + assert session._mqtt_client.disconnect.await_count == 1 + + @pytest.mark.it( + "Starts the SasTokenProvider upon entry into the context manager, and stops it upon exit, if one 
exists" + ) + async def test_sastoken_provider_start_stop(self, session, mock_sastoken_provider): + session._sastoken_provider = mock_sastoken_provider + assert session._sastoken_provider.start.await_count == 0 + assert session._sastoken_provider.stop.await_count == 0 + + async with session as session: + assert session._sastoken_provider.start.await_count == 1 + assert session._sastoken_provider.stop.await_count == 0 + + assert session._sastoken_provider.start.await_count == 1 + assert session._sastoken_provider.stop.await_count == 1 + + @pytest.mark.it( + "Stops the SasTokenProvider upon exit, even if an error was raised within the block inside the context manager" + ) + async def test_sastoken_provider_start_stop_with_failure( + self, session, mock_sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = mock_sastoken_provider + assert session._sastoken_provider.start.await_count == 0 + assert session._sastoken_provider.stop.await_count == 0 + + try: + async with session as session: + assert session._sastoken_provider.start.await_count == 1 + assert session._sastoken_provider.stop.await_count == 0 + raise arbitrary_exception + except type(arbitrary_exception): + pass + + assert session._sastoken_provider.start.await_count == 1 + assert session._sastoken_provider.stop.await_count == 1 + + @pytest.mark.it("Can handle the case where SasTokenProvider does not exist") + async def test_no_sastoken_provider(self, session): + assert session._sastoken_provider is None + + async with session as session: + pass + + # If nothing raises here, the test passes + + @pytest.mark.it( + "Creates a Task from the MQTTClient's .wait_for_disconnect() coroutine method and stores it as the `wait_for_disconnect_task` attribute upon entry into the context manager, and cancels and clears the Task upon exit" + ) + async def test_wait_for_disconnect_task(self, mocker, session): + assert session._wait_for_disconnect_task is None + assert 
session._mqtt_client.wait_for_disconnect.call_count == 0 + + async with session as session: + # Task Created and Method called + assert isinstance(session._wait_for_disconnect_task, asyncio.Task) + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.call_args == mocker.call() + await asyncio.sleep(0.1) + assert session._mqtt_client.wait_for_disconnect.is_hanging() + # Returning method completes task (thus task corresponds to method) + session._mqtt_client.wait_for_disconnect.stop_hanging() + await asyncio.sleep(0.1) + assert session._wait_for_disconnect_task.done() + assert ( + session._wait_for_disconnect_task.result() + is session._mqtt_client.wait_for_disconnect.return_value + ) + # Replace the task with a mock so we can show it is cancelled/cleared on exit + mock_task = mocker.MagicMock() + session._wait_for_disconnect_task = mock_task + assert mock_task.cancel.call_count == 0 + + # Mocked task was cancelled and cleared + assert mock_task.cancel.call_count == 1 + assert session._wait_for_disconnect_task is None + + @pytest.mark.it( + "Cancels and clears the `wait_for_disconnect_task` Task, even if an error was raised within the block inside the context manager" + ) + async def test_wait_for_disconnect_task_with_failure(self, session, arbitrary_exception): + assert session._wait_for_disconnect_task is None + + try: + async with session as session: + task = session._wait_for_disconnect_task + assert task is not None + assert not task.done() + raise arbitrary_exception + except type(arbitrary_exception): + pass + + await asyncio.sleep(0.1) + assert session._wait_for_disconnect_task is None + assert task.done() + assert task.cancelled() + + @pytest.mark.it( + "Allows any errors raised within the block inside the context manager to propagate" + ) + @pytest.mark.parametrize( + "exception", + [ + pytest.param(lazy_fixture("arbitrary_exception"), 
id="Unexpected Exception"), + # NOTE: it is important to test the CancelledError since it is a regular Exception in 3.7, + # but a BaseException from 3.8+ + pytest.param(asyncio.CancelledError(), id="CancelledError"), + ], + ) + async def test_error_propagation(self, session, exception): + with pytest.raises(type(exception)) as e_info: + async with session as session: + raise exception + assert e_info.value is exception + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Allows any errors raised while starting the SasTokenProvider during context manager entry to propagate" + ) + async def test_enter_sastoken_provider_start_raises( + self, session, mock_sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = mock_sastoken_provider + session._sastoken_provider.start.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)) as e_info: + async with session as session: + pass + assert e_info.value is arbitrary_exception + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Does not start or connect the IoTHubMQTTClient, nor create the `wait_for_disconnect_task`, if an error is raised while starting the SasTokenProvider during context manager entry" + ) + async def test_enter_sastoken_provider_start_raises_cleanup( + self, mocker, session, mock_sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = mock_sastoken_provider + session._sastoken_provider.start.side_effect = arbitrary_exception + assert session._sastoken_provider.start.await_count == 0 + assert session._mqtt_client.start.await_count == 0 + assert session._mqtt_client.connect.await_count == 0 + assert session._wait_for_disconnect_task is None + spy_create_task = mocker.spy(asyncio, "create_task") + + with pytest.raises(type(arbitrary_exception)) as e_info: + async with session as session: + pass + assert e_info.value is arbitrary_exception + + assert session._sastoken_provider.start.await_count 
== 1 + assert session._mqtt_client.start.await_count == 0 + assert session._mqtt_client.connect.await_count == 0 + assert session._wait_for_disconnect_task is None + assert spy_create_task.call_count == 0 + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Allows any errors raised while starting the IoTHubMQTTClient during context manager entry to propagate" + ) + async def test_enter_mqtt_client_start_raises(self, session, arbitrary_exception): + session._mqtt_client.start.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)) as e_info: + async with session as session: + pass + assert e_info.value is arbitrary_exception + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Stops the IoTHubMQTTClient and SasTokenProvider (if present) that were previously started, and does not create the `wait_for_disconnect_task`, if an error is raised while starting the IoTHubMQTTClient during context manager entry" + ) + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="SasTokenProvider present"), + pytest.param(None, id="SasTokenProvider not present"), + ], + ) + async def test_enter_mqtt_client_start_raises_cleanup( + self, mocker, session, sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = sastoken_provider + session._mqtt_client.start.side_effect = arbitrary_exception + assert session._mqtt_client.stop.await_count == 0 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 0 + assert session._wait_for_disconnect_task is None + spy_create_task = mocker.spy(asyncio, "create_task") + + with pytest.raises(type(arbitrary_exception)): + async with session as session: + pass + + assert session._mqtt_client.stop.await_count == 1 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 1 + assert session._wait_for_disconnect_task is None + assert 
spy_create_task.call_count == 0 + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Stops the SasTokenProvider (if present) even if an error was raised while stopping the IoTHubMQTTClient in response to an error raised while starting the IoTHubMQTTClient during context manager entry" + ) + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="SasTokenProvider present"), + pytest.param(None, id="SasTokenProvider not present"), + ], + ) + async def test_enter_mqtt_client_start_raises_then_mqtt_client_stop_raises( + self, session, sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = sastoken_provider + session._mqtt_client.start.side_effect = arbitrary_exception + session._mqtt_client.stop.side_effect = arbitrary_exception + assert session._mqtt_client.stop.await_count == 0 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 0 + + with pytest.raises(type(arbitrary_exception)): + async with session as session: + pass + + assert session._mqtt_client.stop.await_count == 1 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 1 + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Stops the IoTHubMQTTClient even if an error was raised while stopping the SasTokenProvider in response to an error raised while starting the IoTHubMQTTClient during context manager entry" + ) + async def test_enter_mqtt_client_start_raises_then_sastoken_provider_stop_raises( + self, session, mock_sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = mock_sastoken_provider + session._mqtt_client.start.side_effect = arbitrary_exception + session._sastoken_provider.stop.side_effect = arbitrary_exception + assert session._mqtt_client.stop.await_count == 0 + assert session._sastoken_provider.stop.await_count == 0 + + with pytest.raises(type(arbitrary_exception)): + async with 
session as session: + pass + + assert session._mqtt_client.stop.await_count == 1 + assert session._sastoken_provider.stop.await_count == 1 + + @pytest.mark.it( + "Allows any errors raised while connecting with the IoTHubMQTTClient during context manager entry to propagate" + ) + @pytest.mark.parametrize( + "exception", + [ + pytest.param(exc.MQTTConnectionFailedError(), id="MQTTConnectionFailedError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_enter_mqtt_client_connect_raises(self, session, exception): + session._mqtt_client.connect.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + async with session as session: + pass + assert e_info.value is exception + + @pytest.mark.it( + "Stops the IoTHubMQTTClient and SasTokenProvider (if present) that were previously started, and does not create the `wait_for_disconnect_task`, if an error is raised while connecting during context manager entry" + ) + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="SasTokenProvider present"), + pytest.param(None, id="SasTokenProvider not present"), + ], + ) + @pytest.mark.parametrize( + "exception", + [ + pytest.param(exc.MQTTConnectionFailedError(), id="MQTTConnectionFailedError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_enter_mqtt_client_connect_raises_cleanup( + self, mocker, session, sastoken_provider, exception + ): + session._sastoken_provider = sastoken_provider + session._mqtt_client.connect.side_effect = exception + assert session._mqtt_client.stop.await_count == 0 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 0 + assert session._wait_for_disconnect_task is None + spy_create_task = mocker.spy(asyncio, "create_task") + + with pytest.raises(type(exception)): + async with session as session: + pass + + assert 
session._mqtt_client.stop.await_count == 1 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 1 + assert session._wait_for_disconnect_task is None + assert spy_create_task.call_count == 0 + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Stops the SasTokenProvider (if present) even if an error was raised while stopping the IoTHubMQTTClient in response to an error raised while connecting the IoTHubMQTTClient during context manager entry" + ) + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="SasTokenProvider present"), + pytest.param(None, id="SasTokenProvider not present"), + ], + ) + @pytest.mark.parametrize( + "exception", + [ + pytest.param(exc.MQTTConnectionFailedError(), id="MQTTConnectionFailedError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_enter_mqtt_client_connect_raises_then_mqtt_client_stop_raises( + self, session, sastoken_provider, exception, arbitrary_exception + ): + session._sastoken_provider = sastoken_provider + session._mqtt_client.connect.side_effect = exception # Realistic failure + session._mqtt_client.stop.side_effect = arbitrary_exception # Shouldn't happen + assert session._mqtt_client.stop.await_count == 0 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 0 + + # NOTE: arbitrary_exception is raised here instead of exception - this is because it + # happened second, during resolution of exception, thus taking precedence + with pytest.raises(type(arbitrary_exception)): + async with session as session: + pass + + assert session._mqtt_client.stop.await_count == 1 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 1 + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Stops the IoTHubMQTTClient even if an error was raised while stopping the 
SasTokenProvider in response to an error raised while connecting the IoTHubMQTTClient during context manager entry" + ) + @pytest.mark.parametrize( + "exception", + [ + pytest.param(exc.MQTTConnectionFailedError(), id="MQTTConnectionFailedError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_enter_mqtt_client_connect_raises_then_sastoken_provider_stop_raises( + self, session, mock_sastoken_provider, exception, arbitrary_exception + ): + session._sastoken_provider = mock_sastoken_provider + session._mqtt_client.connect.side_effect = exception # Realistic failure + session._sastoken_provider.stop.side_effect = arbitrary_exception # Shouldn't happen + assert session._mqtt_client.stop.await_count == 0 + assert session._sastoken_provider.stop.await_count == 0 + + # NOTE: arbitrary_exception is raised here instead of exception - this is because it + # happened second, during resolution of exception, thus taking precedence + with pytest.raises(type(arbitrary_exception)): + async with session as session: + pass + + assert session._mqtt_client.stop.await_count == 1 + assert session._sastoken_provider.stop.await_count == 1 + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Allows any errors raised while disconnecting the IoTHubMQTTClient during context manager exit to propagate" + ) + async def test_exit_disconnect_raises(self, session, arbitrary_exception): + session._mqtt_client.disconnect.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)) as e_info: + async with session as session: + pass + assert e_info.value is arbitrary_exception + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Stops the IoTHubMQTTClient and SasTokenProvider (if present), and cancels and clears the `wait_for_disconnect_task`, even if an error was raised while disconnecting the IoTHubMQTTClient during context manager exit" + ) + @pytest.mark.parametrize( 
+ "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="SasTokenProvider present"), + pytest.param(None, id="SasTokenProvider not present"), + ], + ) + async def test_exit_disconnect_raises_cleanup( + self, session, sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = sastoken_provider + session._mqtt_client.disconnect.side_effect = arbitrary_exception + assert session._mqtt_client.stop.await_count == 0 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 0 + assert session._wait_for_disconnect_task is None + + with pytest.raises(type(arbitrary_exception)): + async with session as session: + conn_drop_task = session._wait_for_disconnect_task + assert not conn_drop_task.done() + + assert session._mqtt_client.stop.await_count == 1 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 1 + await asyncio.sleep(0.1) + assert session._wait_for_disconnect_task is None + assert conn_drop_task.cancelled() + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Allows any errors raised while stopping the IoTHubMQTTClient during context manager exit to propagate" + ) + async def test_exit_mqtt_client_stop_raises(self, session, arbitrary_exception): + session._mqtt_client.stop.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)) as e_info: + async with session as session: + pass + assert e_info.value is arbitrary_exception + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Disconnects the IoTHubMQTTClient and stops the SasTokenProvider (if present), and cancels and clears the `wait_for_disconnect_task`, even if an error was raised while stopping the IoTHubMQTTClient during context manager exit" + ) + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="SasTokenProvider present"), + pytest.param(None, 
id="SasTokenProvider not present"), + ], + ) + async def test_exit_mqtt_client_stop_raises_cleanup( + self, session, sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = sastoken_provider + session._mqtt_client.stop.side_effect = arbitrary_exception + assert session._mqtt_client.disconnect.await_count == 0 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 0 + assert session._wait_for_disconnect_task is None + + with pytest.raises(type(arbitrary_exception)): + async with session as session: + conn_drop_task = session._wait_for_disconnect_task + assert not conn_drop_task.done() + + assert session._mqtt_client.disconnect.await_count == 1 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 1 + await asyncio.sleep(0.1) + assert session._wait_for_disconnect_task is None + assert conn_drop_task.cancelled() + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Allows any errors raised while stopping the SasTokenProvider during context manager exit to propagate" + ) + async def test_exit_sastoken_provider_stop_raises( + self, session, mock_sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = mock_sastoken_provider + session._sastoken_provider.stop.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)): + async with session as session: + pass + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Disconnects and stops the IoTHubMQTTClient, and cancels and clears the `wait_for_disconnect_task`, even if an error was raised while stopping the SasTokenProvider during context manager exit" + ) + async def test_exit_sastoken_provider_stop_raises_cleanup( + self, session, mock_sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = mock_sastoken_provider + session._sastoken_provider.stop.side_effect = arbitrary_exception + assert 
session._mqtt_client.disconnect.await_count == 0 + assert session._mqtt_client.stop.await_count == 0 + assert session._wait_for_disconnect_task is None + + with pytest.raises(type(arbitrary_exception)): + async with session as session: + conn_drop_task = session._wait_for_disconnect_task + assert not conn_drop_task.done() + + assert session._mqtt_client.disconnect.await_count == 1 + assert session._mqtt_client.stop.await_count == 1 + await asyncio.sleep(0.1) + assert session._wait_for_disconnect_task is None + assert conn_drop_task.cancelled() + + # TODO: consider adding detailed cancellation tests + # Not sure how cancellation would work in a context manager situation, needs more investigation + + +@pytest.mark.describe("IoTHubSession - .send_message()") +class TestIoTHubSessionSendMessage: + @pytest.mark.it( + "Invokes .send_message() on the IoTHubMQTTClient, passing the provided `message`, if `message` is a Message object" + ) + async def test_message_object(self, mocker, session): + assert session._mqtt_client.send_message.await_count == 0 + + m = models.Message("hi") + await session.send_message(m) + + assert session._mqtt_client.send_message.await_count == 1 + assert session._mqtt_client.send_message.await_args == mocker.call(m) + + @pytest.mark.it( + "Invokes .send_message() on the IoTHubMQTTClient, passing a new Message object with `message` as the payload, if `message` is a string" + ) + async def test_message_string(self, session): + assert session._mqtt_client.send_message.await_count == 0 + + m_str = "hi" + await session.send_message(m_str) + + assert session._mqtt_client.send_message.await_count == 1 + m_obj = session._mqtt_client.send_message.await_args[0][0] + assert isinstance(m_obj, models.Message) + assert m_obj.payload == m_str + + @pytest.mark.it("Allows any exceptions raised by the IoTHubMQTTClient to propagate") + @pytest.mark.parametrize( + "exception", + [ + pytest.param(exc.MQTTError(5), id="MQTTError"), + pytest.param(ValueError(), 
id="ValueError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Error"), + ], + ) + async def test_mqtt_client_raises(self, session, exception): + session._mqtt_client.send_message.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await session.send_message("hi") + assert e_info.value is exception + + @pytest.mark.it( + "Raises SessionError without invoking .send_message() on the IoTHubMQTTClient if it is not connected" + ) + async def test_not_connected(self, mocker, session): + conn_property_mock = mocker.PropertyMock(return_value=False) + type(session._mqtt_client).connected = conn_property_mock + + with pytest.raises(exc.SessionError): + await session.send_message("hi") + assert session._mqtt_client.send_message.call_count == 0 + + @pytest.mark.it( + "Raises CancelledError if an expected disconnect occurs in the IoTHubMQTTClient while waiting for the operation to complete" + ) + async def test_expected_disconnect_during_send(self, session): + session._mqtt_client.send_message = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(session.send_message("hi")) + + # Hanging, waiting for send to finish + await session._mqtt_client.send_message.wait_for_hang() + assert not t.done() + + # No disconnect yet + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.is_hanging() + + # Simulate expected disconnect + session._mqtt_client.wait_for_disconnect.return_value = None + session._mqtt_client.wait_for_disconnect.stop_hanging() + + with pytest.raises(asyncio.CancelledError): + await t + + @pytest.mark.it( + "Raises the MQTTConnectionDroppedError that caused the unexpected disconnect, if an unexpected disconnect occurs in the IoTHubMQTTClient while waiting for the operation to complete" + ) + async def test_unexpected_disconnect_during_send(self, session): + session._mqtt_client.send_message = 
custom_mock.HangingAsyncMock() + + t = asyncio.create_task(session.send_message("hi")) + + # Hanging, waiting for send to finish + await session._mqtt_client.send_message.wait_for_hang() + assert not t.done() + + # No unexpected disconnect yet + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.is_hanging() + + # Simulate unexpected disconnect + cause = exc.MQTTConnectionDroppedError(rc=7) + session._mqtt_client.wait_for_disconnect.return_value = cause + session._mqtt_client.wait_for_disconnect.stop_hanging() + + with pytest.raises(exc.MQTTConnectionDroppedError) as e_info: + await t + assert e_info.value is cause + + @pytest.mark.it("Can be cancelled while waiting for the IoTHubMQTTClient operation to complete") + async def test_cancel_during_send(self, session): + session._mqtt_client.send_message = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(session.send_message("hi")) + + # Hanging, waiting for send to finish + await session._mqtt_client.send_message.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe("IoTHubSession - .send_direct_method_response()") +class TestIoTHubSessionSendDirectMethodResponse: + @pytest.fixture + def direct_method_response(self): + return models.DirectMethodResponse(request_id="id", status=200, payload={"some": "value"}) + + @pytest.mark.it( + "Invokes .send_direct_method_response() on the IoTHubMQTTClient, passing the provided `method_response`" + ) + async def test_invoke(self, mocker, session, direct_method_response): + assert session._mqtt_client.send_direct_method_response.await_count == 0 + + await session.send_direct_method_response(direct_method_response) + + assert session._mqtt_client.send_direct_method_response.await_count == 1 + assert session._mqtt_client.send_direct_method_response.await_args == 
mocker.call( + direct_method_response + ) + + @pytest.mark.it("Allows any exceptions raised by the IoTHubMQTTClient to propagate") + @pytest.mark.parametrize( + "exception", + [ + pytest.param(exc.MQTTError(5), id="MQTTError"), + pytest.param(ValueError(), id="ValueError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Error"), + ], + ) + async def test_mqtt_client_raises(self, session, direct_method_response, exception): + session._mqtt_client.send_direct_method_response.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await session.send_direct_method_response(direct_method_response) + assert e_info.value is exception + + @pytest.mark.it( + "Raises SessionError without invoking .send_direct_method_response() on the IoTHubMQTTClient if it is not connected" + ) + async def test_not_connected(self, mocker, session, direct_method_response): + conn_property_mock = mocker.PropertyMock(return_value=False) + type(session._mqtt_client).connected = conn_property_mock + + with pytest.raises(exc.SessionError): + await session.send_direct_method_response(direct_method_response) + assert session._mqtt_client.send_message.call_count == 0 + + @pytest.mark.it( + "Raises CancelledError if an expected disconnect occurs in the IoTHubMQTTClient while waiting for the operation to complete" + ) + async def test_expected_disconnect_during_send(self, session, direct_method_response): + session._mqtt_client.send_direct_method_response = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(session.send_direct_method_response(direct_method_response)) + + # Hanging, waiting for send to finish + await session._mqtt_client.send_direct_method_response.wait_for_hang() + assert not t.done() + + # No disconnect yet + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.is_hanging() + + # Simulate expected disconnect + 
session._mqtt_client.wait_for_disconnect.return_value = None + session._mqtt_client.wait_for_disconnect.stop_hanging() + + with pytest.raises(asyncio.CancelledError): + await t + + @pytest.mark.it( + "Raises the MQTTConnectionDroppedError that caused the unexpected disconnect, if an unexpected disconnect occurs in the IoTHubMQTTClient while waiting for the operation to complete" + ) + async def test_unexpected_disconnect_during_send(self, session, direct_method_response): + session._mqtt_client.send_direct_method_response = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(session.send_direct_method_response(direct_method_response)) + + # Hanging, waiting for send to finish + await session._mqtt_client.send_direct_method_response.wait_for_hang() + assert not t.done() + + # No unexpected disconnect yet + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.is_hanging() + + # Simulate unexpected disconnect + cause = exc.MQTTConnectionDroppedError(rc=7) + session._mqtt_client.wait_for_disconnect.return_value = cause + session._mqtt_client.wait_for_disconnect.stop_hanging() + + with pytest.raises(exc.MQTTConnectionDroppedError) as e_info: + await t + assert e_info.value is cause + + @pytest.mark.it("Can be cancelled while waiting for the IoTHubMQTTClient operation to complete") + async def test_cancel_during_send(self, session, direct_method_response): + session._mqtt_client.send_direct_method_response = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(session.send_direct_method_response(direct_method_response)) + + # Hanging, waiting for send to finish + await session._mqtt_client.send_direct_method_response.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe("IoTHubSession - .update_reported_properties()") +class 
TestIoTHubSessionUpdateReportedProperties: + @pytest.fixture + def patch(self): + return {"key1": "value1", "key2": "value2"} + + @pytest.mark.it( + "Invokes .send_twin_patch() on the IoTHubMQTTClient, passing the provided `patch`" + ) + async def test_invoke(self, mocker, session, patch): + assert session._mqtt_client.send_twin_patch.await_count == 0 + + await session.update_reported_properties(patch) + + assert session._mqtt_client.send_twin_patch.await_count == 1 + assert session._mqtt_client.send_twin_patch.await_args == mocker.call(patch) + + @pytest.mark.it("Allows any exceptions raised by the IoTHubMQTTClient to propagate") + @pytest.mark.parametrize( + "exception", + [ + pytest.param(exc.IoTHubError(), id="IoTHubError"), + pytest.param(exc.MQTTError(5), id="MQTTError"), + pytest.param(ValueError(), id="ValueError"), + pytest.param(asyncio.CancelledError(), id="CancelledError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Error"), + ], + ) + async def test_mqtt_client_raises(self, session, patch, exception): + session._mqtt_client.send_twin_patch.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await session.update_reported_properties(patch) + # CancelledError doesn't propagate in some versions of Python + # TODO: determine which versions exactly + if not isinstance(exception, asyncio.CancelledError): + assert e_info.value is exception + + @pytest.mark.it( + "Raises SessionError without invoking .send_twin_patch() on the IoTHubMQTTClient if it is not connected" + ) + async def test_not_connected(self, mocker, session, patch): + conn_property_mock = mocker.PropertyMock(return_value=False) + type(session._mqtt_client).connected = conn_property_mock + + with pytest.raises(exc.SessionError): + await session.update_reported_properties(patch) + assert session._mqtt_client.send_twin_patch.call_count == 0 + + @pytest.mark.it( + "Raises CancelledError if an expected disconnect occurs in the IoTHubMQTTClient while 
waiting for the operation to complete" + ) + async def test_expected_disconnect_during_send(self, session, patch): + session._mqtt_client.send_twin_patch = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(session.update_reported_properties(patch)) + + # Hanging, waiting for send to finish + await session._mqtt_client.send_twin_patch.wait_for_hang() + assert not t.done() + + # No disconnect yet + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.is_hanging() + + # Simulate expected disconnect + session._mqtt_client.wait_for_disconnect.return_value = None + session._mqtt_client.wait_for_disconnect.stop_hanging() + + with pytest.raises(asyncio.CancelledError): + await t + + @pytest.mark.it( + "Raises the MQTTConnectionDroppedError that caused the unexpected disconnect, if an unexpected disconnect occurs in the IoTHubMQTTClient while waiting for the operation to complete" + ) + async def test_unexpected_disconnect_during_send(self, session, patch): + session._mqtt_client.send_twin_patch = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(session.update_reported_properties(patch)) + + # Hanging, waiting for send to finish + await session._mqtt_client.send_twin_patch.wait_for_hang() + assert not t.done() + + # No unexpected disconnect yet + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.is_hanging() + + # Simulate unexpected disconnect + cause = exc.MQTTConnectionDroppedError(rc=7) + session._mqtt_client.wait_for_disconnect.return_value = cause + session._mqtt_client.wait_for_disconnect.stop_hanging() + + with pytest.raises(exc.MQTTConnectionDroppedError) as e_info: + await t + assert e_info.value is cause + + @pytest.mark.it("Can be cancelled while waiting for the IoTHubMQTTClient operation to complete") + async def 
test_cancel_during_send(self, session, patch): + session._mqtt_client.send_twin_patch = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(session.update_reported_properties(patch)) + + # Hanging, waiting for send to finish + await session._mqtt_client.send_twin_patch.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe("IoTHubSession - .get_twin()") +class TestIoTHubSessionGetTwin: + @pytest.mark.it("Invokes .get_twin() on the IoTHubMQTTClient") + async def test_invoke(self, mocker, session): + assert session._mqtt_client.get_twin.await_count == 0 + + await session.get_twin() + + assert session._mqtt_client.get_twin.await_count == 1 + assert session._mqtt_client.get_twin.await_args == mocker.call() + + @pytest.mark.it("Allows any exceptions raised by the IoTHubMQTTClient to propagate") + @pytest.mark.parametrize( + "exception", + [ + pytest.param(exc.IoTHubError(), id="IoTHubError"), + pytest.param(exc.MQTTError(5), id="MQTTError"), + pytest.param(asyncio.CancelledError(), id="CancelledError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Error"), + ], + ) + async def test_mqtt_client_raises(self, session, exception): + session._mqtt_client.get_twin.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await session.get_twin() + # CancelledError doesn't propagate in some versions of Python + # TODO: determine which versions exactly + if not isinstance(exception, asyncio.CancelledError): + assert e_info.value is exception + + @pytest.mark.it( + "Raises SessionError without invoking .get_twin() on the IoTHubMQTTClient if it is not connected" + ) + async def test_not_connected(self, mocker, session): + conn_property_mock = mocker.PropertyMock(return_value=False) + type(session._mqtt_client).connected = conn_property_mock + + with pytest.raises(exc.SessionError): + await session.get_twin() + assert 
session._mqtt_client.get_twin.call_count == 0 + + @pytest.mark.it( + "Raises CancelledError if an expected disconnect occurs in the IoTHubMQTTClient while waiting for the operation to complete" + ) + async def test_expected_disconnect_during_send(self, session): + session._mqtt_client.get_twin = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(session.get_twin()) + + # Hanging, waiting for send to finish + await session._mqtt_client.get_twin.wait_for_hang() + assert not t.done() + + # No disconnect yet + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.is_hanging() + + # Simulate expected disconnect + session._mqtt_client.wait_for_disconnect.return_value = None + session._mqtt_client.wait_for_disconnect.stop_hanging() + + with pytest.raises(asyncio.CancelledError): + await t + + @pytest.mark.it( + "Raises the MQTTConnectionDroppedError that caused the unexpected disconnect, if an unexpected disconnect occurs in the IoTHubMQTTClient while waiting for the operation to complete" + ) + async def test_unexpected_disconnect_during_send(self, session): + session._mqtt_client.get_twin = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(session.get_twin()) + + # Hanging, waiting for send to finish + await session._mqtt_client.get_twin.wait_for_hang() + assert not t.done() + + # No unexpected disconnect yet + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.is_hanging() + + # Simulate unexpected disconnect + cause = exc.MQTTConnectionDroppedError(rc=7) + session._mqtt_client.wait_for_disconnect.return_value = cause + session._mqtt_client.wait_for_disconnect.stop_hanging() + + with pytest.raises(exc.MQTTConnectionDroppedError) as e_info: + await t + assert e_info.value is cause + + @pytest.mark.it("Can be cancelled while 
waiting for the IoTHubMQTTClient operation to complete") + async def test_cancel_during_send(self, session): + session._mqtt_client.get_twin = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(session.get_twin()) + + # Hanging, waiting for send to finish + await session._mqtt_client.get_twin.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe("IoTHubSession - .messages()") +class TestIoTHubSessionMessages: + @pytest.mark.it( + "Enables C2D message receive with the IoTHubMQTTClient upon entry into the context manager and disables C2D message receive upon exit" + ) + async def test_context_manager(self, session): + assert session._mqtt_client.enable_c2d_message_receive.await_count == 0 + assert session._mqtt_client.disable_c2d_message_receive.await_count == 0 + + async with session.messages(): + assert session._mqtt_client.enable_c2d_message_receive.await_count == 1 + assert session._mqtt_client.disable_c2d_message_receive.await_count == 0 + + assert session._mqtt_client.enable_c2d_message_receive.await_count == 1 + assert session._mqtt_client.disable_c2d_message_receive.await_count == 1 + + @pytest.mark.it( + "Disables C2D message receive upon exit, even if an error is raised inside the context manager block" + ) + async def test_context_manager_failure(self, session, arbitrary_exception): + assert session._mqtt_client.enable_c2d_message_receive.await_count == 0 + assert session._mqtt_client.disable_c2d_message_receive.await_count == 0 + + try: + async with session.messages(): + assert session._mqtt_client.enable_c2d_message_receive.await_count == 1 + assert session._mqtt_client.disable_c2d_message_receive.await_count == 0 + raise arbitrary_exception + except type(arbitrary_exception): + pass + + assert session._mqtt_client.enable_c2d_message_receive.await_count == 1 + assert session._mqtt_client.disable_c2d_message_receive.await_count == 1 + + @pytest.mark.it( + 
"Does not attempt to disable C2D message receive upon exit if IoTHubMQTTClient is disconnected" + ) + @pytest.mark.parametrize("graceful_exit", graceful_exit_params) + async def test_context_manager_exit_while_disconnected( + self, session, arbitrary_exception, graceful_exit + ): + assert session._mqtt_client.enable_c2d_message_receive.await_count == 0 + assert session._mqtt_client.disable_c2d_message_receive.await_count == 0 + + try: + async with session.messages(): + assert session._mqtt_client.enable_c2d_message_receive.await_count == 1 + assert session._mqtt_client.disable_c2d_message_receive.await_count == 0 + session._mqtt_client.connected = False + if not graceful_exit: + raise arbitrary_exception + except type(arbitrary_exception): + pass + + assert session._mqtt_client.enable_c2d_message_receive.await_count == 1 + assert session._mqtt_client.disable_c2d_message_receive.await_count == 0 + + @pytest.mark.it( + "Yields an AsyncGenerator that yields the C2D messages yielded by the IoTHubMQTTClient's incoming C2D message generator" + ) + async def test_generator_yield(self, mocker, session): + # Mock IoTHubMQTTClient C2D generator to yield Messages + yielded_c2d_messages = [models.Message("1"), models.Message("2"), models.Message("3")] + mock_c2d_gen = mocker.AsyncMock() + mock_c2d_gen.__anext__.side_effect = yielded_c2d_messages + # Set it to be returned by PropertyMock + c2d_gen_property_mock = mocker.PropertyMock(return_value=mock_c2d_gen) + type(session._mqtt_client).incoming_c2d_messages = c2d_gen_property_mock + + assert not session._wait_for_disconnect_task.done() + async with session.messages() as messages: + # Is a generator + assert isinstance(messages, typing.AsyncGenerator) + # Yields values from the IoTHubMQTTClient C2D generator + assert mock_c2d_gen.__anext__.await_count == 0 + val = await messages.__anext__() + assert val is yielded_c2d_messages[0] + assert mock_c2d_gen.__anext__.await_count == 1 + val = await messages.__anext__() + assert val 
is yielded_c2d_messages[1] + assert mock_c2d_gen.__anext__.await_count == 2 + val = await messages.__anext__() + assert val is yielded_c2d_messages[2] + assert mock_c2d_gen.__anext__.await_count == 3 + + @pytest.mark.it( + "Yields an AsyncGenerator that will raise the MQTTConnectionDroppedError that caused an unexpected disconnect in the IoTHubMQTTClient in the event of an unexpected disconnection" + ) + async def test_generator_raise_unexpected_disconnect(self, mocker, session): + # Mock IoTHubMQTTClient C2D generator to not yield anything yet + mock_c2d_gen = mocker.AsyncMock() + mock_c2d_gen.__anext__ = custom_mock.HangingAsyncMock() + # Set it to be returned by PropertyMock + c2d_gen_property_mock = mocker.PropertyMock(return_value=mock_c2d_gen) + type(session._mqtt_client).incoming_c2d_messages = c2d_gen_property_mock + + async with session.messages() as messages: + # Waiting for new item from generator (since mock is hanging / not returning) + t = asyncio.create_task(messages.__anext__()) + await asyncio.sleep(0.1) + assert not t.done() + + # No unexpected disconnect yet + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.is_hanging() + + # Trigger unexpected disconnect + cause = exc.MQTTConnectionDroppedError(rc=7) + session._mqtt_client.wait_for_disconnect.return_value = cause + session._mqtt_client.wait_for_disconnect.stop_hanging() + await asyncio.sleep(0.1) + + # Generator raised the error that caused disconnect + assert t.done() + assert t.exception() is cause + + @pytest.mark.it( + "Yields an AsyncGenerator that will raise a CancelledError if the event of an expected disconnection" + ) + async def test_generator_raise_expected_disconnect(self, mocker, session): + # Mock IoTHubMQTTClient C2D generator to not yield anything yet + mock_c2d_gen = mocker.AsyncMock() + mock_c2d_gen.__anext__ = custom_mock.HangingAsyncMock() + # Set it to be 
returned by PropertyMock + c2d_gen_property_mock = mocker.PropertyMock(return_value=mock_c2d_gen) + type(session._mqtt_client).incoming_c2d_messages = c2d_gen_property_mock + + async with session.messages() as messages: + # Waiting for new item from generator (since mock is hanging / not returning) + t = asyncio.create_task(messages.__anext__()) + await asyncio.sleep(0.1) + assert not t.done() + + # No disconnect yet + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.is_hanging() + + # Trigger expected disconnect + session._mqtt_client.wait_for_disconnect.return_value = None + session._mqtt_client.wait_for_disconnect.stop_hanging() + await asyncio.sleep(0.1) + + # Generator raised CancelledError + assert t.done() + assert t.cancelled() + + @pytest.mark.it( + "Raises SessionError without enabling C2D message receive on the IoTHubMQTTClient if it is not connected" + ) + async def test_not_connected(self, mocker, session): + conn_property_mock = mocker.PropertyMock(return_value=False) + type(session._mqtt_client).connected = conn_property_mock + + with pytest.raises(exc.SessionError): + async with session.messages(): + pass + assert session._mqtt_client.enable_c2d_message_receive.call_count == 0 + + @pytest.mark.it( + "Allows any errors raised while attempting to enable C2D message receive to propagate" + ) + @pytest.mark.parametrize( + "exception", + [ + pytest.param(exc.MQTTError(rc=4), id="MQTTError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_enable_raises(self, session, exception): + session._mqtt_client.enable_c2d_message_receive.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + async with session.messages(): + pass + assert e_info.value is exception + assert session._mqtt_client.enable_c2d_message_receive.call_count == 1 + + @pytest.mark.it( + "Suppresses 
any MQTTErrors raised while attempting to disable C2D message receive" + ) + async def test_disable_raises_mqtt_error(self, session): + session._mqtt_client.disable_c2d_message_receive.side_effect = exc.MQTTError(rc=4) + + async with session.messages(): + pass + assert session._mqtt_client.disable_c2d_message_receive.call_count == 1 + # No error raised -> success + + @pytest.mark.it( + "Allows any unexpected errors raised while attempting to disable C2D message receive to propagate" + ) + async def test_disable_raises_unexpected(self, session, arbitrary_exception): + session._mqtt_client.disable_c2d_message_receive.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)) as e_info: + async with session.messages(): + pass + assert e_info.value is arbitrary_exception + assert session._mqtt_client.disable_c2d_message_receive.call_count == 1 + + +@pytest.mark.describe("IoTHubSession - .direct_method_requests()") +class TestIoTHubSessionDirectMethodRequests: + @pytest.mark.it( + "Enables direct method request receive with the IoTHubMQTTClient upon entry into the context manager and disables direct method request receive upon exit" + ) + async def test_context_manager(self, session): + assert session._mqtt_client.enable_direct_method_request_receive.await_count == 0 + assert session._mqtt_client.disable_direct_method_request_receive.await_count == 0 + + async with session.direct_method_requests(): + assert session._mqtt_client.enable_direct_method_request_receive.await_count == 1 + assert session._mqtt_client.disable_direct_method_request_receive.await_count == 0 + + assert session._mqtt_client.enable_direct_method_request_receive.await_count == 1 + assert session._mqtt_client.disable_direct_method_request_receive.await_count == 1 + + @pytest.mark.it( + "Disables direct method request receive upon exit, even if an error is raised inside the context manager block" + ) + async def test_context_manager_failure(self, session, arbitrary_exception): 
+ assert session._mqtt_client.enable_direct_method_request_receive.await_count == 0 + assert session._mqtt_client.disable_direct_method_request_receive.await_count == 0 + + try: + async with session.direct_method_requests(): + assert session._mqtt_client.enable_direct_method_request_receive.await_count == 1 + assert session._mqtt_client.disable_direct_method_request_receive.await_count == 0 + raise arbitrary_exception + except type(arbitrary_exception): + pass + + assert session._mqtt_client.enable_direct_method_request_receive.await_count == 1 + assert session._mqtt_client.disable_direct_method_request_receive.await_count == 1 + + @pytest.mark.it( + "Does not attempt to disable direct method request receive upon exit if IoTHubMQTTClient is disconnected" + ) + @pytest.mark.parametrize("graceful_exit", graceful_exit_params) + async def test_context_manager_exit_while_disconnected( + self, session, arbitrary_exception, graceful_exit + ): + assert session._mqtt_client.enable_direct_method_request_receive.await_count == 0 + assert session._mqtt_client.disable_direct_method_request_receive.await_count == 0 + + try: + async with session.direct_method_requests(): + assert session._mqtt_client.enable_direct_method_request_receive.await_count == 1 + assert session._mqtt_client.disable_direct_method_request_receive.await_count == 0 + session._mqtt_client.connected = False + if not graceful_exit: + raise arbitrary_exception + except type(arbitrary_exception): + pass + + assert session._mqtt_client.enable_direct_method_request_receive.await_count == 1 + assert session._mqtt_client.disable_direct_method_request_receive.await_count == 0 + + @pytest.mark.it( + "Yields an AsyncGenerator that yields the direct method requests yielded by the IoTHubMQTTClient's incoming direct method request message generator" + ) + async def test_generator_yield(self, mocker, session): + # Mock IoTHubMQTTClient direct method request generator to yield DirectMethodRequests + 
yielded_direct_method_requests = [ + models.DirectMethodRequest("1", "m1", ""), + models.DirectMethodRequest("2", "m2", ""), + models.DirectMethodRequest("3", "m3", ""), + ] + mock_dm_gen = mocker.AsyncMock() + mock_dm_gen.__anext__.side_effect = yielded_direct_method_requests + # Set it to be returned by PropertyMock + dm_gen_property_mock = mocker.PropertyMock(return_value=mock_dm_gen) + type(session._mqtt_client).incoming_direct_method_requests = dm_gen_property_mock + + assert not session._wait_for_disconnect_task.done() + async with session.direct_method_requests() as direct_method_requests: + # Is a generator + assert isinstance(direct_method_requests, typing.AsyncGenerator) + # Yields values from the IoTHubMQTTClient direct method request generator + assert mock_dm_gen.__anext__.await_count == 0 + val = await direct_method_requests.__anext__() + assert val is yielded_direct_method_requests[0] + assert mock_dm_gen.__anext__.await_count == 1 + val = await direct_method_requests.__anext__() + assert val is yielded_direct_method_requests[1] + assert mock_dm_gen.__anext__.await_count == 2 + val = await direct_method_requests.__anext__() + assert val is yielded_direct_method_requests[2] + assert mock_dm_gen.__anext__.await_count == 3 + + @pytest.mark.it( + "Yields an AsyncGenerator that will raise the MQTTConnectionDroppedError that caused an unexpected disconnect in the IoTHubMQTTClient in the event of an unexpected disconnection" + ) + async def test_generator_raise_unexpected_disconnect(self, mocker, session): + # Mock IoTHubMQTTClient direct method request generator to not yield anything yet + mock_dm_gen = mocker.AsyncMock() + mock_dm_gen.__anext__ = custom_mock.HangingAsyncMock() + # Set it to be returned by PropertyMock + dm_gen_property_mock = mocker.PropertyMock(return_value=mock_dm_gen) + type(session._mqtt_client).incoming_direct_method_requests = dm_gen_property_mock + + async with session.direct_method_requests() as direct_method_requests: + # Waiting 
for new item from generator (since mock is hanging / not returning) + t = asyncio.create_task(direct_method_requests.__anext__()) + await asyncio.sleep(0.1) + assert not t.done() + + # No unexpected disconnect yet + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.is_hanging() + + # Trigger unexpected disconnect + cause = exc.MQTTConnectionDroppedError(rc=7) + session._mqtt_client.wait_for_disconnect.return_value = cause + session._mqtt_client.wait_for_disconnect.stop_hanging() + await asyncio.sleep(0.1) + + # Generator raised the error that caused disconnect + assert t.done() + assert t.exception() is cause + + @pytest.mark.it( + "Yields an AsyncGenerator that will raise a CancelledError if the event of an expected disconnection" + ) + async def test_generator_raise_expected_disconnect(self, mocker, session): + # Mock IoTHubMQTTClient direct method request generator to not yield anything yet + mock_dm_gen = mocker.AsyncMock() + mock_dm_gen.__anext__ = custom_mock.HangingAsyncMock() + # Set it to be returned by PropertyMock + dm_gen_property_mock = mocker.PropertyMock(return_value=mock_dm_gen) + type(session._mqtt_client).incoming_direct_method_requests = dm_gen_property_mock + + async with session.direct_method_requests() as direct_method_requests: + # Waiting for new item from generator (since mock is hanging / not returning) + t = asyncio.create_task(direct_method_requests.__anext__()) + await asyncio.sleep(0.1) + assert not t.done() + + # No disconnect yet + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.is_hanging() + + # Trigger expected disconnect + session._mqtt_client.wait_for_disconnect.return_value = None + session._mqtt_client.wait_for_disconnect.stop_hanging() + await asyncio.sleep(0.1) + + # Generator raised 
CancelledError + assert t.done() + assert t.cancelled() + + @pytest.mark.it( + "Raises SessionError without enabling direct method request receive on the IoTHubMQTTClient if it is not connected" + ) + async def test_not_connected(self, mocker, session): + conn_property_mock = mocker.PropertyMock(return_value=False) + type(session._mqtt_client).connected = conn_property_mock + + with pytest.raises(exc.SessionError): + async with session.direct_method_requests(): + pass + assert session._mqtt_client.enable_direct_method_request_receive.call_count == 0 + + @pytest.mark.it( + "Allows any errors raised while attempting to enable direct method request receive to propagate" + ) + @pytest.mark.parametrize( + "exception", + [ + pytest.param(exc.MQTTError(rc=4), id="MQTTError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_enable_raises(self, session, exception): + session._mqtt_client.enable_direct_method_request_receive.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + async with session.direct_method_requests(): + pass + assert e_info.value is exception + assert session._mqtt_client.enable_direct_method_request_receive.call_count == 1 + + @pytest.mark.it( + "Suppresses any MQTTErrors raised while attempting to disable direct method request receive" + ) + async def test_disable_raises_mqtt_error(self, session): + session._mqtt_client.disable_direct_method_request_receive.side_effect = exc.MQTTError(rc=4) + + async with session.direct_method_requests(): + pass + assert session._mqtt_client.disable_direct_method_request_receive.call_count == 1 + # No error raised -> success + + @pytest.mark.it( + "Allows any unexpected errors raised while attempting to disable direct method request receive to propagate" + ) + async def test_disable_raises_unexpected(self, session, arbitrary_exception): + session._mqtt_client.disable_direct_method_request_receive.side_effect = arbitrary_exception + + with 
pytest.raises(type(arbitrary_exception)) as e_info: + async with session.direct_method_requests(): + pass + assert e_info.value is arbitrary_exception + assert session._mqtt_client.disable_direct_method_request_receive.call_count == 1 + + +@pytest.mark.describe("IoTHubSession - .desired_property_updates()") +class TestIoTHubSessionDesiredPropertyUpdates: + @pytest.mark.it( + "Enables twin patch receive with the IoTHubMQTTClient upon entry into the context manager and disables twin patch receive upon exit" + ) + async def test_context_manager(self, session): + assert session._mqtt_client.enable_twin_patch_receive.await_count == 0 + assert session._mqtt_client.disable_twin_patch_receive.await_count == 0 + + async with session.desired_property_updates(): + assert session._mqtt_client.enable_twin_patch_receive.await_count == 1 + assert session._mqtt_client.disable_twin_patch_receive.await_count == 0 + + assert session._mqtt_client.enable_twin_patch_receive.await_count == 1 + assert session._mqtt_client.disable_twin_patch_receive.await_count == 1 + + @pytest.mark.it( + "Disables twin patch receive upon exit, even if an error is raised inside the context manager block" + ) + async def test_context_manager_failure(self, session, arbitrary_exception): + assert session._mqtt_client.enable_twin_patch_receive.await_count == 0 + assert session._mqtt_client.disable_twin_patch_receive.await_count == 0 + + try: + async with session.desired_property_updates(): + assert session._mqtt_client.enable_twin_patch_receive.await_count == 1 + assert session._mqtt_client.disable_twin_patch_receive.await_count == 0 + raise arbitrary_exception + except type(arbitrary_exception): + pass + + assert session._mqtt_client.enable_twin_patch_receive.await_count == 1 + assert session._mqtt_client.disable_twin_patch_receive.await_count == 1 + + @pytest.mark.it( + "Does not attempt to disable twin patch receive upon exit if IoTHubMQTTClient is disconnected" + ) + 
@pytest.mark.parametrize("graceful_exit", graceful_exit_params) + async def test_context_manager_exit_while_disconnected( + self, session, arbitrary_exception, graceful_exit + ): + assert session._mqtt_client.enable_twin_patch_receive.await_count == 0 + assert session._mqtt_client.disable_twin_patch_receive.await_count == 0 + + try: + async with session.desired_property_updates(): + assert session._mqtt_client.enable_twin_patch_receive.await_count == 1 + assert session._mqtt_client.disable_twin_patch_receive.await_count == 0 + session._mqtt_client.connected = False + if not graceful_exit: + raise arbitrary_exception + except type(arbitrary_exception): + pass + + assert session._mqtt_client.enable_twin_patch_receive.await_count == 1 + assert session._mqtt_client.disable_twin_patch_receive.await_count == 0 + + @pytest.mark.it( + "Yields an AsyncGenerator that yields the desired property patches yielded by the IoTHubMQTTClient's incoming twin patch generator" + ) + async def test_generator_yield(self, mocker, session): + # Mock IoTHubMQTTClient twin patch generator to yield twin patches + yielded_twin_patches = [{"1": 1}, {"2": 2}, {"3": 3}] + mock_twin_patch_gen = mocker.AsyncMock() + mock_twin_patch_gen.__anext__.side_effect = yielded_twin_patches + # Set it to be returned by PropertyMock + twin_patch_property_mock = mocker.PropertyMock(return_value=mock_twin_patch_gen) + type(session._mqtt_client).incoming_twin_patches = twin_patch_property_mock + + assert not session._wait_for_disconnect_task.done() + async with session.desired_property_updates() as desired_property_updates: + # Is a generator + assert isinstance(desired_property_updates, typing.AsyncGenerator) + # Yields values from the IoTHubMQTTClient C2D generator + assert mock_twin_patch_gen.__anext__.await_count == 0 + val = await desired_property_updates.__anext__() + assert val is yielded_twin_patches[0] + assert mock_twin_patch_gen.__anext__.await_count == 1 + val = await 
desired_property_updates.__anext__() + assert val is yielded_twin_patches[1] + assert mock_twin_patch_gen.__anext__.await_count == 2 + val = await desired_property_updates.__anext__() + assert val is yielded_twin_patches[2] + assert mock_twin_patch_gen.__anext__.await_count == 3 + + @pytest.mark.it( + "Yields an AsyncGenerator that will raise the MQTTConnectionDroppedError that caused an unexpected disconnect in the IoTHubMQTTClient in the event of an unexpected disconnection" + ) + async def test_generator_raise_unexpected_disconnect(self, mocker, session): + # Mock IoTHubMQTTClient twin patch generator to not yield anything yet + mock_twin_patch_gen = mocker.AsyncMock() + mock_twin_patch_gen.__anext__ = custom_mock.HangingAsyncMock() + # Set it to be returned by PropertyMock + twin_patch_property_mock = mocker.PropertyMock(return_value=mock_twin_patch_gen) + type(session._mqtt_client).incoming_twin_patches = twin_patch_property_mock + + async with session.desired_property_updates() as desired_property_updates: + # Waiting for new item from generator (since mock is hanging / not returning) + t = asyncio.create_task(desired_property_updates.__anext__()) + await asyncio.sleep(0.1) + assert not t.done() + + # No unexpected disconnect yet + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.is_hanging() + + # Trigger unexpected disconnect + cause = exc.MQTTConnectionDroppedError(rc=7) + session._mqtt_client.wait_for_disconnect.return_value = cause + session._mqtt_client.wait_for_disconnect.stop_hanging() + await asyncio.sleep(0.1) + + # Generator raised the error that caused disconnect + assert t.done() + assert t.exception() is cause + + @pytest.mark.it( + "Yields an AsyncGenerator that will raise a CancelledError if the event of an expected disconnection" + ) + async def test_generator_raise_expected_disconnect(self, mocker, session): + # Mock 
IoTHubMQTTClient twin patch generator to not yield anything yet + mock_twin_patch_gen = mocker.AsyncMock() + mock_twin_patch_gen.__anext__ = custom_mock.HangingAsyncMock() + # Set it to be returned by PropertyMock + twin_patch_property_mock = mocker.PropertyMock(return_value=mock_twin_patch_gen) + type(session._mqtt_client).incoming_twin_patches = twin_patch_property_mock + + async with session.desired_property_updates() as desired_property_updates: + # Waiting for new item from generator (since mock is hanging / not returning) + t = asyncio.create_task(desired_property_updates.__anext__()) + await asyncio.sleep(0.1) + assert not t.done() + + # No unexpected disconnect yet + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.is_hanging() + + # Trigger expected disconnect + session._mqtt_client.wait_for_disconnect.return_value = None + session._mqtt_client.wait_for_disconnect.stop_hanging() + await asyncio.sleep(0.1) + + # Generator raised CancelledError + assert t.done() + assert t.cancelled() + + @pytest.mark.it( + "Raises SessionError without enabling twin patch receive on the IoTHubMQTTClient if it is not connected" + ) + async def test_not_connected(self, mocker, session): + conn_property_mock = mocker.PropertyMock(return_value=False) + type(session._mqtt_client).connected = conn_property_mock + + with pytest.raises(exc.SessionError): + async with session.desired_property_updates(): + pass + assert session._mqtt_client.enable_twin_patch_receive.call_count == 0 + + @pytest.mark.it( + "Allows any errors raised while attempting to enable twin patch receive to propagate" + ) + @pytest.mark.parametrize( + "exception", + [ + pytest.param(exc.MQTTError(rc=4), id="MQTTError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_enable_raises(self, session, exception): + 
session._mqtt_client.enable_twin_patch_receive.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + async with session.desired_property_updates(): + pass + assert e_info.value is exception + assert session._mqtt_client.enable_twin_patch_receive.call_count == 1 + + @pytest.mark.it( + "Suppresses any MQTTErrors raised while attempting to disable twin patch receive" + ) + async def test_disable_raises_mqtt_error(self, session): + session._mqtt_client.disable_twin_patch_receive.side_effect = exc.MQTTError(rc=4) + + async with session.desired_property_updates(): + pass + assert session._mqtt_client.disable_twin_patch_receive.call_count == 1 + # No error raised -> success + + @pytest.mark.it( + "Allows any unexpected errors raised while attempting to disable twin patch receive to propagate" + ) + async def test_disable_raises_unexpected(self, session, arbitrary_exception): + session._mqtt_client.disable_twin_patch_receive.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)) as e_info: + async with session.desired_property_updates(): + pass + assert e_info.value is arbitrary_exception + assert session._mqtt_client.disable_twin_patch_receive.call_count == 1 diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py new file mode 100644 index 000000000..3acf2278d --- /dev/null +++ b/tests/unit/test_models.py @@ -0,0 +1,281 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import pytest +import logging +from azure.iot.device.models import Message, DirectMethodRequest, DirectMethodResponse +from azure.iot.device import constant + +logging.basicConfig(level=logging.DEBUG) + +FAKE_RID = "123" +FAKE_METHOD_NAME = "some_method" +FAKE_STATUS = 200 + + +json_serializable_payload_params = [ + pytest.param("String payload", id="String Payload"), + pytest.param(123, id="Integer Payload"), + pytest.param(2.0, id="Float Payload"), + pytest.param(True, id="Boolean Payload"), + pytest.param({"dictionary": {"payload": "nested"}}, id="Dictionary Payload"), + pytest.param([1, 2, 3], id="List Payload"), + pytest.param((1, 2, 3), id="Tuple Payload"), + pytest.param(None, id="No Payload"), +] + + +@pytest.mark.describe("Message") +class TestMessage: + @pytest.mark.it("Instantiates with the provided payload set as an attribute") + @pytest.mark.parametrize("payload", json_serializable_payload_params) + def test_instantiates_from_data(self, payload): + msg = Message(payload) + assert msg.payload == payload + + @pytest.mark.it( + "Instantiates with optional provided content type and content encoding set as attributes" + ) + @pytest.mark.parametrize("content_type", ["text/plain", "application/json"]) + @pytest.mark.parametrize("content_encoding", ["utf-8", "utf-16", "utf-32"]) + def test_instantiates_with_optional_contenttype_encoding(self, content_type, content_encoding): + msg = Message("some message", content_encoding, content_type) + assert msg.content_encoding == content_encoding + assert msg.content_type == content_type + + @pytest.mark.it("Defaults content encoding to 'utf-8' if not provided") + def test_default_content_encoding(self): + msg = Message("some message") + assert msg.content_encoding == "utf-8" + + @pytest.mark.it("Raises ValueError if unsupported content encoding provided") + def test_unsupported_content_encoding(self): + with pytest.raises(ValueError): + 
Message("some message", content_encoding="ascii") + + @pytest.mark.it("Defaults content type to 'text/plain' if not provided") + def test_default_content_type(self): + msg = Message("some message") + assert msg.content_type == "text/plain" + + @pytest.mark.it("Raises ValueError if unsupported content type provided") + def test_unsupported_content_type(self): + with pytest.raises(ValueError): + Message("some message", content_type="text/javascript") + + @pytest.mark.it("Instantiates with optional provided output name set as an attribute") + def test_instantiates_with_optional_output_name(self): + output_name = "some_output" + msg = Message("some message", output_name=output_name) + assert msg.output_name == output_name + + @pytest.mark.it("Instantiates with no message id set") + def test_default_message_id(self): + msg = Message("some message") + assert msg.message_id is None + + @pytest.mark.it("Instantiates with no custom properties set") + def test_default_custom_properties(self): + msg = Message("some message") + assert msg.custom_properties == {} + + @pytest.mark.it("Instantiates with no set input name") + def test_default_input_name(self): + msg = Message("some message") + assert msg.input_name is None + + @pytest.mark.it("Instantiates with no set ack value") + def test_default_ack(self): + msg = Message("some message") + assert msg.ack is None + + @pytest.mark.it("Instantiates with no set expiry time") + def test_default_expiry_time(self): + msg = Message("some message") + assert msg.expiry_time_utc is None + + @pytest.mark.it("Instantiates with no set user id") + def test_default_user_id(self): + msg = Message("some message") + assert msg.user_id is None + + @pytest.mark.it("Instantiates with no set correlation id") + def test_default_corr_id(self): + msg = Message("some message") + assert msg.correlation_id is None + + @pytest.mark.it("Instantiates with no set iothub_interface_id (i.e. 
not as a security message)") + def test_default_security_msg_status(self): + msg = Message("some message") + assert msg.iothub_interface_id is None + + @pytest.mark.it("Maintains iothub_interface_id (security message) as a read-only property") + def test_read_only_iothub_interface_id(self): + msg = Message("some message") + with pytest.raises(AttributeError): + msg.iothub_interface_id = "value" + + @pytest.mark.it( + "Uses string representation of data/payload attribute as string representation of Message" + ) + @pytest.mark.parametrize("payload", json_serializable_payload_params) + def test_str_rep(self, payload): + msg = Message(payload) + assert str(msg) == str(payload) + + @pytest.mark.it("Can be set as a security message via API") + def test_setting_message_as_security_message(self): + ctype = "application/json" + encoding = "utf-16" + msg = Message("some message", encoding, ctype) + assert msg.iothub_interface_id is None + msg.set_as_security_message() + assert msg.iothub_interface_id == constant.SECURITY_MESSAGE_INTERFACE_ID + + # NOTE: This test tests all system properties, even though they shouldn't all be present simultaneously + @pytest.mark.it("Can return the system properties set on the Message as a dictionary via API") + def test_system_properties_dict_all(self): + msg = Message("some message") + msg.message_id = "message id" + msg.content_encoding = "application/json" + msg.content_type = "utf-16" + msg.output_name = "output name" + msg._iothub_interface_id = "interface id" + msg.input_name = "input name" + msg.ack = "value" + msg.expiry_time_utc = "time" + msg.user_id = "user id" + msg.correlation_id = "correlation id" + sys_prop = msg.get_system_properties_dict() + + assert sys_prop["$.mid"] == msg.message_id + assert sys_prop["$.ce"] == msg.content_encoding + assert sys_prop["$.ct"] == msg.content_type + assert sys_prop["$.on"] == msg.output_name + assert sys_prop["$.ifid"] == msg._iothub_interface_id + assert sys_prop["$.to"] == msg.input_name + 
assert sys_prop["iothub-ack"] == msg.ack + assert sys_prop["$.exp"] == msg.expiry_time_utc + assert sys_prop["$.uid"] == msg.user_id + assert sys_prop["$.cid"] == msg.correlation_id + + @pytest.mark.it( + "Only contains the system properties present on the Message in the system properties dictionary" + ) + def test_system_properties_dict_partial(self): + msg = Message("some message") + msg.message_id = "message id" + assert msg.content_encoding is not None + assert msg.content_type is not None + + sys_prop = msg.get_system_properties_dict() + assert len(sys_prop) == 3 + assert sys_prop["$.mid"] == msg.message_id + assert sys_prop["$.ce"] == msg.content_encoding + assert sys_prop["$.ct"] == msg.content_type + + # NOTE: This test tests all system properties, even though they shouldn't all be present simultaneously + @pytest.mark.it("Can be instantiated from a properties dictionary") + @pytest.mark.parametrize( + "custom_properties", + [ + pytest.param({}, id="System Properties Only"), + pytest.param( + {"cust1": "v1", "cust2": "v2"}, id="System Properties and Custom Properties" + ), + ], + ) + def test_create_from_dict(self, custom_properties): + system_properties = { + "$.mid": "message id", + "$.ce": "application/json", + "$.ct": "utf-16", + "$.on": "output name", + "$.ifid": "interface id", + "$.to": "input name", + "iothub-ack": "value", + "$.exp": "time", + "$.uid": "user id", + "$.cid": "correlation id", + } + properties = dict(system_properties) + properties.update(custom_properties) + message = Message.create_from_properties_dict("some payload", properties) + + assert message.message_id == system_properties["$.mid"] + assert message.content_encoding == system_properties["$.ce"] + assert message.content_type == system_properties["$.ct"] + assert message.output_name == system_properties["$.on"] + assert message._iothub_interface_id == system_properties["$.ifid"] + assert message.input_name == system_properties["$.to"] + assert message.ack == 
system_properties["iothub-ack"] + assert message.expiry_time_utc == system_properties["$.exp"] + assert message.user_id == system_properties["$.uid"] + assert message.correlation_id == system_properties["$.cid"] + + for key in custom_properties: + assert message.custom_properties[key] == custom_properties[key] + + @pytest.mark.it( + "Uses default values for system properties when creating from a properties dictionary if they are not in the properties dictionary" + ) + def test_create_from_dict_defaults(self): + properties = { + "$.mid": "message id", + } + message = Message.create_from_properties_dict("some payload", properties) + assert message.content_encoding == "utf-8" + assert message.content_type == "text/plain" + + +@pytest.mark.describe("DirectMethodRequest") +class TestDirectMethodRequest: + @pytest.mark.it("Instantiates with the provided 'request_id' set as an attribute") + def test_request_id(self): + m_req = DirectMethodRequest(request_id=FAKE_RID, name=FAKE_METHOD_NAME, payload={}) + assert m_req.request_id == FAKE_RID + + @pytest.mark.it("Instantiates with the provided 'name' set as an attribute") + def test_name(self): + m_req = DirectMethodRequest(request_id=FAKE_RID, name=FAKE_METHOD_NAME, payload={}) + assert m_req.name == FAKE_METHOD_NAME + + @pytest.mark.it("Instantiates with the provided 'payload' set as an attribute") + @pytest.mark.parametrize("payload", json_serializable_payload_params) + def test_payload(self, payload): + m_req = DirectMethodRequest(request_id=FAKE_RID, name=FAKE_METHOD_NAME, payload=payload) + assert m_req.payload == payload + + +@pytest.mark.describe("DirectMethodResponse") +class TestDirectMethodResponse: + @pytest.mark.it("Instantiates with the provided 'request_id' set as an attribute") + def test_request_id(self): + m_resp = DirectMethodResponse(request_id=FAKE_RID, status=FAKE_STATUS, payload={}) + assert m_resp.request_id == FAKE_RID + + @pytest.mark.it("Instantiates with the provided 'status' set as an attribute") 
+ def test_status(self): + m_resp = DirectMethodResponse(request_id=FAKE_RID, status=FAKE_STATUS, payload={}) + assert m_resp.status == FAKE_STATUS + + @pytest.mark.it("Instantiates with the optional provided 'payload' set as an attribute") + @pytest.mark.parametrize("payload", json_serializable_payload_params) + def test_payload(self, payload): + m_resp = DirectMethodResponse(request_id=FAKE_RID, status=FAKE_STATUS, payload=payload) + assert m_resp.payload == payload + + @pytest.mark.it("Can be instantiated from a DirectMethodResponse via factory API") + @pytest.mark.parametrize("payload", json_serializable_payload_params) + def test_factory(self, payload): + m_req = DirectMethodRequest(request_id=FAKE_RID, name=FAKE_METHOD_NAME, payload={}) + m_resp = DirectMethodResponse.create_from_method_request( + method_request=m_req, status=FAKE_STATUS, payload=payload + ) + assert isinstance(m_resp, DirectMethodResponse) + assert m_resp.request_id == m_req.request_id + assert m_resp.status == FAKE_STATUS + assert m_resp.payload == payload diff --git a/tests/unit/test_mqtt_client.py b/tests/unit/test_mqtt_client.py new file mode 100644 index 000000000..2dd1db8ec --- /dev/null +++ b/tests/unit/test_mqtt_client.py @@ -0,0 +1,3387 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +from azure.iot.device.mqtt_client import ( + MQTTClient, + MQTTError, + MQTTConnectionFailedError, + MQTTConnectionDroppedError, +) +from azure.iot.device.mqtt_client import ( + expected_connect_rc, + expected_subscribe_rc, + expected_unsubscribe_rc, + expected_publish_rc, + expected_on_connect_rc, + expected_on_disconnect_rc, +) +from azure.iot.device.config import ProxyOptions +import paho.mqtt.client as mqtt +import asyncio +import pytest +import threading +import time +from concurrent.futures import ThreadPoolExecutor + + +fake_device_id = "MyDevice" +fake_hostname = "fake.hostname" +fake_password = "fake_password" +fake_username = fake_hostname + "/" + fake_device_id +fake_port = 443 +fake_keepalive = 1234 +fake_ws_path = "/fake/path" +fake_topic = "/some/topic/" +fake_payload = "message content" + +PAHO_STATE_NEW = "NEW" +PAHO_STATE_DISCONNECTED = "DISCONNECTED" +PAHO_STATE_CONNECTED = "CONNECTED" +PAHO_STATE_CONNECTION_LOST = "CONNECTION_LOST" + +UNEXPECTED_PAHO_RC = 255 + +ACK_DELAY = 1 + + +@pytest.fixture(scope="module") +def paho_threadpool(): + # Paho has a single thread it invokes handlers on + tpe = ThreadPoolExecutor(max_workers=1) + yield tpe + tpe.shutdown() + + +@pytest.fixture +def mock_paho(mocker, paho_threadpool): + """This mock is quite a bit more complicated than your average mock in order to + capture some of the weirder Paho behaviors""" + mock_paho = mocker.MagicMock() + # Define a fake internal connection state for Paho. + # You should not ever have to touch this manually. Please don't. + # + # It is further worth noting that this state is different from the one used in + # the real implementation, because Paho doesn't store true connection state, just a + # "desired" connection state (which itself is different from our client's .desire_connection). + # The true connection state is derived by other means (sockets). 
+ # For simplicity, I've rolled all the information relevant to mocking behavior into a + # 4-state value. + mock_paho._state = PAHO_STATE_NEW + # Used to mock out loop_forever behavior. + mock_paho._network_loop_exit = threading.Event() + # Indicates whether or not invocations should automatically trigger callbacks + mock_paho._manual_mode = False + # Indicates whether or not invocations should trigger callbacks immediately + # (i.e. before invocation return) + # NOTE: While the "normal" behavior we can expect is NOT an early ack, we set early ack + # as the default for test performance reasons + mock_paho._early_ack = True + # Default rc value to return on invocations of method mocks + # NOTE: There is no _disconnect_rc because disconnect return values are deterministic + # See the implementation of trigger_on_disconnect and the mock disconnect below. + mock_paho._connect_rc = mqtt.MQTT_ERR_SUCCESS + mock_paho._publish_rc = mqtt.MQTT_ERR_SUCCESS + mock_paho._subscribe_rc = mqtt.MQTT_ERR_SUCCESS + mock_paho._unsubscribe_rc = mqtt.MQTT_ERR_SUCCESS + # Last mid that was returned. Will be incremented over time (see _get_next_mid()) + # NOTE: 0 means no mid has been sent yet + mock_paho._last_mid = 0 + + # Utility helpers + # NOTE: PLEASE USE THESE WHEN WRITING TESTS SO YOU DON'T HAVE TO WORRY ABOUT STATE + def trigger_on_connect(rc=mqtt.CONNACK_ACCEPTED): + if rc == mqtt.CONNACK_ACCEPTED: + # State is only set to connected if successfully connecting + mock_paho._state = PAHO_STATE_CONNECTED + else: + # If it fails it ends up in a "new" state. 
+ mock_paho._state = PAHO_STATE_NEW + if not mock_paho._early_ack: + paho_threadpool.submit(time.sleep, ACK_DELAY) + paho_threadpool.submit( + mock_paho.on_connect, client=mock_paho, userdata=None, flags=None, rc=rc + ) + + mock_paho.trigger_on_connect = trigger_on_connect + + def trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS): + if mock_paho._state == PAHO_STATE_CONNECTED: + mock_paho._state = PAHO_STATE_CONNECTION_LOST + if not mock_paho._early_ack: + paho_threadpool.submit(time.sleep, ACK_DELAY) + # Need to signal that loop_forever will return now (if not already signaled) + if not mock_paho._network_loop_exit.is_set(): + mock_paho._network_loop_exit.set() + paho_threadpool.submit(mock_paho.on_disconnect, client=mock_paho, userdata=None, rc=rc) + + mock_paho.trigger_on_disconnect = trigger_on_disconnect + + def trigger_on_subscribe(mid=None): + if not mid: + mid = mock_paho._last_mid + if not mock_paho._early_ack: + paho_threadpool.submit(time.sleep, ACK_DELAY) + paho_threadpool.submit( + mock_paho.on_subscribe, client=mock_paho, userdata=None, mid=mid, granted_qos=1 + ) + + mock_paho.trigger_on_subscribe = trigger_on_subscribe + + def trigger_on_unsubscribe(mid=None): + if not mid: + mid = mock_paho._last_mid + if not mock_paho._early_ack: + paho_threadpool.submit(time.sleep, ACK_DELAY) + paho_threadpool.submit(mock_paho.on_unsubscribe, client=mock_paho, userdata=None, mid=mid) + + mock_paho.trigger_on_unsubscribe = trigger_on_unsubscribe + + def trigger_on_publish(mid): + if not mid: + mid = mock_paho._last_mid + if not mock_paho._early_ack: + paho_threadpool.submit(time.sleep, ACK_DELAY) + paho_threadpool.submit(mock_paho.on_publish, client=mock_paho, userdata=None, mid=mid) + + mock_paho.trigger_on_publish = trigger_on_publish + + # NOTE: This should not be necessary to use in any tests themselves. 
+ def _get_next_mid(): + mock_paho._last_mid += 1 + mid = mock_paho._last_mid + return mid + + mock_paho._get_next_mid = _get_next_mid + + # Method mocks + def is_connected(*args, **kwargs): + """ + NOT TO BE CONFUSED WITH MQTTClient.is_connected()!!!! + This is Paho's inner state. It returns True even if connection has been lost. + """ + return mock_paho._state != PAHO_STATE_DISCONNECTED + + def loop_forever(*args, **kwargs): + """ + Blocks until network loop exit (on disconnect). + This is necessary as a Future gets made from this method, and whether or not it is + done affects the logic, so we can't just return immediately. + """ + mock_paho._network_loop_exit.clear() + return mock_paho._network_loop_exit.wait() + + def connect(*args, **kwargs): + # Only trigger completion if not in manual mode + # Only trigger completion if returning success + if not mock_paho._manual_mode and mock_paho._connect_rc == mqtt.MQTT_ERR_SUCCESS: + mock_paho.trigger_on_connect() + return mock_paho._connect_rc + + def disconnect(*args, **kwargs): + # NOTE: THERE IS NO WAY TO OVERRIDE THIS RETURN VALUE AS IT IS DETERMINISTIC + # BASED ON THE PAHO STATE + if mock_paho._state == PAHO_STATE_CONNECTED: + mock_paho._state = PAHO_STATE_DISCONNECTED + if not mock_paho._manual_mode: + mock_paho.trigger_on_disconnect() + rc = mqtt.MQTT_ERR_SUCCESS + else: + mock_paho._state = PAHO_STATE_DISCONNECTED + rc = mqtt.MQTT_ERR_NO_CONN + # We don't trigger on_disconnect, but do need to exit network loop if it's running. + # This only happens in cancellation scenarios. 
+ if not mock_paho._network_loop_exit.is_set(): + mock_paho._network_loop_exit.set() + return rc + + def subscribe(*args, **kwargs): + if mock_paho._subscribe_rc != mqtt.MQTT_ERR_SUCCESS: + mid = None + else: + mid = mock_paho._get_next_mid() + if not mock_paho._manual_mode: + mock_paho.trigger_on_subscribe(mid) + return (mock_paho._subscribe_rc, mid) + + def unsubscribe(*args, **kwargs): + if mock_paho._unsubscribe_rc != mqtt.MQTT_ERR_SUCCESS: + mid = None + else: + mid = mock_paho._get_next_mid() + if not mock_paho._manual_mode: + mock_paho.trigger_on_unsubscribe(mid) + return (mock_paho._unsubscribe_rc, mid) + + def publish(*args, **kwargs): + # Unlike subscribe and unsubscribe, publish still returns a mid in the case of failure + mid = mock_paho._get_next_mid() + if not mock_paho._manual_mode: + mock_paho.trigger_on_publish(mid) + # Not going to bother mocking out the details of this message info since we just use it + # for the rc and mid + msg_info = mqtt.MQTTMessageInfo(mid) + msg_info.rc = mock_paho._publish_rc + return msg_info + + mock_paho.is_connected.side_effect = is_connected + mock_paho.loop_forever.side_effect = loop_forever + mock_paho.connect.side_effect = connect + mock_paho.disconnect.side_effect = disconnect + mock_paho.subscribe.side_effect = subscribe + mock_paho.unsubscribe.side_effect = unsubscribe + mock_paho.publish.side_effect = publish + + mocker.patch.object(mqtt, "Client", return_value=mock_paho) + + return mock_paho + + +@pytest.fixture +async def fresh_client(mock_paho): + # NOTE: Implicitly imports the mocked Paho MQTT Client due to patch in mock_paho + client = MQTTClient( + client_id=fake_device_id, hostname=fake_hostname, port=fake_port, auto_reconnect=False + ) + assert client._mqtt_client is mock_paho + yield client + + # Reset any mock paho settings that might affect ability to disconnect + mock_paho._manual_mode = False + await client.disconnect() + + +@pytest.fixture +async def client(fresh_client): + return fresh_client + 
+ +# Helper functions for changing client state. +# +# Always use these to set the state during tests so that the client state and Paho state +# do not get out of sync. +# There's also network loop Futures running in other threads you don't want to have to +# consider when writing a test. +# +# Arguably invoking .cancel() on a task can put things in additional "states", but the tests +# in this module approach cancellation and it's effects as modifications of a state rather than +# itself being a state. The tests themselves should make this clear. +def client_set_connected(client): + """Set the client to a connected state""" + client._connected = True + client._desire_connection = True + client._disconnection_cause = None + client._mqtt_client._state = PAHO_STATE_CONNECTED + # A client after a connection should have a currently running network loop Future + event_loop = asyncio.get_running_loop() + client._network_loop = event_loop.run_in_executor(None, client._mqtt_client.loop_forever) + + +def client_set_disconnected(client): + """Set the client to an (intentionally) disconnected state""" + client._connected = False + client._desire_connection = False + client._disconnection_cause = None + client._mqtt_client._state = PAHO_STATE_DISCONNECTED + # Ensure any running network loop Future exits, then clean up + # An (intentionally) disconnected client should have no network loop Future at all + client._mqtt_client._network_loop_exit.set() + client._network_loop = None + + +def client_set_connection_dropped(client): + """Set the client to a state representing an unexpected disconnect""" + client._connected = False + client._desire_connection = True + client._disconnection_cause = MQTTConnectionDroppedError(rc=7) + client._mqtt_client._state = PAHO_STATE_CONNECTION_LOST + # Ensure any running network loop Future exits. 
+ # A client after a connection drop should have a completed network loop Future + client._mqtt_client._network_loop_exit.set() + if not client._network_loop: + client._network_loop = asyncio.Future() + client._network_loop.set_result(None) + + +def client_set_fresh(client): + """Set a client to a fresh state. + This could either be a client that has never been connected or a client that has had a + connection failure (even if it was previously connected). This is because Paho resets its + state when making a connection attempt. + + FOR ALL INTENTS AND PURPOSES A CLIENT IN THIS STATE SHOULD BEHAVE EXACTLY THE SAME AS + ONE IN A DISCONNECTED STATE. USE THE SAME TESTS FOR BOTH. + """ + client._connected = False + client._desire_connection = False + client._disconnection_cause = None + client._mqtt_client._state = PAHO_STATE_NEW + # Ensure any running network loop Future exits, then clean up + # A fresh client should have no network loop Future at all + client._mqtt_client._network_loop_exit.set() + client._mqtt_client._network_loop_exit.clear() # as if it were never set + client._network_loop = None + + +# Pytest parametrizations +early_ack_params = [ + pytest.param(False, id="Response after invocation returns"), + pytest.param(True, id="Response before invocation returns"), +] + +# NOTE: disconnect rcs are not necessary as disconnect can't fail and the result is deterministic +# (See mock_paho implementation for more information) +# TODO: add raised exception params when we know which ones to expect +connect_failed_rc_params = [ + pytest.param(UNEXPECTED_PAHO_RC, id="Unexpected Paho result"), +] +subscribe_failed_rc_params = [ + pytest.param(mqtt.MQTT_ERR_NO_CONN, id="MQTT_ERR_NO_CONN"), + pytest.param(UNEXPECTED_PAHO_RC, id="Unexpected Paho result"), +] +unsubscribe_failed_rc_params = [ + pytest.param(mqtt.MQTT_ERR_NO_CONN, id="MQTT_ERR_NO_CONN"), + pytest.param(UNEXPECTED_PAHO_RC, id="Unexpected Paho result"), +] +publish_failed_rc_params = [ + # Publish can also 
return MQTT_ERR_NO_CONN, but it isn't a failure + pytest.param(mqtt.MQTT_ERR_QUEUE_SIZE, id="MQTT_ERR_QUEUE_SIZE"), + pytest.param(UNEXPECTED_PAHO_RC, id="Unexpected Paho result"), +] +on_connect_failed_rc_params = [ + pytest.param(mqtt.CONNACK_REFUSED_PROTOCOL_VERSION, id="CONNACK_REFUSED_PROTOCOL_VERSION"), + pytest.param( + mqtt.CONNACK_REFUSED_IDENTIFIER_REJECTED, id="CONNACK_REFUSED_IDENTIFIER_REJECTED" + ), + pytest.param(mqtt.CONNACK_REFUSED_SERVER_UNAVAILABLE, id="CONNACK_REFUSED_SERVER_UNAVAILABLE"), + pytest.param( + mqtt.CONNACK_REFUSED_BAD_USERNAME_PASSWORD, id="CONNACK_REFUSED_BAD_USERNAME_PASSWORD" + ), + pytest.param(mqtt.CONNACK_REFUSED_NOT_AUTHORIZED, id="CONNACK_REFUSED_NOT_AUTHORIZED"), + pytest.param( + UNEXPECTED_PAHO_RC, id="Unexpected Paho result" + ), # Reserved for future use defined by MQTT +] +on_disconnect_failed_rc_params = [ + pytest.param(mqtt.MQTT_ERR_CONN_REFUSED, id="MQTT_ERR_CONN_REFUSED"), + pytest.param(mqtt.MQTT_ERR_CONN_LOST, id="MQTT_ERR_CONN_LOST"), + pytest.param(mqtt.MQTT_ERR_KEEPALIVE, id="MQTT_ERR_KEEPALIVE"), + pytest.param(UNEXPECTED_PAHO_RC, id="Unexpected Paho result"), +] + + +# Validate the above are correct so failure will occur if tests are out of date. 
def validate_rc_params(rc_params, expected_rc, no_fail=()):
    """Sanity-check a failing-rc parametrization list against the full set of expected rcs.

    :param rc_params: pytest.param list of failing rc values for an operation
    :param expected_rc: all rc values the operation is expected to be able to return
        (including success values)
    :param no_fail: rc values that can be returned but do not constitute failure, and thus
        should not appear in rc_params
    """
    # FIX: default was a mutable `[]` (shared across calls); use an immutable tuple and
    # normalize with list() so callers may still pass any iterable (e.g. a list).
    no_fail = list(no_fail)
    # Ignore success, and any other failing rcs that don't result in failure
    ignore = [mqtt.MQTT_ERR_SUCCESS, mqtt.CONNACK_ACCEPTED] + no_fail
    # Assert that all expected rcs (other than ignored vals) are in our rc params
    for rc in [v for v in expected_rc if v not in ignore]:
        assert True in [rc in param.values for param in rc_params]
    # Assert that our unexpected rc stand-in is in our rc params
    assert True in [UNEXPECTED_PAHO_RC in param.values for param in rc_params]
    # Assert that there are not more values in our rc params than we would expect
    expected_len = len(expected_rc) - 1  # No success
    expected_len += 1  # We have an additional unexpected value
    expected_len -= len(no_fail)  # No non-fails
    assert len(rc_params) == expected_len


validate_rc_params(connect_failed_rc_params, expected_connect_rc)
validate_rc_params(subscribe_failed_rc_params, expected_subscribe_rc)
validate_rc_params(unsubscribe_failed_rc_params, expected_unsubscribe_rc)
validate_rc_params(publish_failed_rc_params, expected_publish_rc, no_fail=[mqtt.MQTT_ERR_NO_CONN])
validate_rc_params(on_connect_failed_rc_params, expected_on_connect_rc)
validate_rc_params(on_disconnect_failed_rc_params, expected_on_disconnect_rc)


###############################################################################
# TESTS START                                                                 #
###############################################################################


@pytest.mark.describe("MQTTClient - Instantiation")
class TestInstantiation:
    @pytest.fixture(
        params=["HTTP - No Auth", "HTTP - Auth", "SOCKS4", "SOCKS5 - No Auth", "SOCKS5 - Auth"]
    )
    def proxy_options(self, request):
        if "HTTP" in request.param:
            proxy_type = "HTTP"
        elif "SOCKS4" in request.param:
            proxy_type = "SOCKS4"
        else:
            proxy_type = "SOCKS5"

        if "No Auth" in request.param:
            proxy = ProxyOptions(
                proxy_type=proxy_type, proxy_address="fake.address", proxy_port=1080
            )
        else:
            proxy = ProxyOptions(
                proxy_type=proxy_type,
                proxy_address="fake.address",
                proxy_port=1080,
                proxy_username="fake_username",
                proxy_password="fake_password",
            )
        return proxy

    @pytest.fixture(params=["TCP", "WebSockets"])
    def transport(self, request):
        return request.param.lower()

    @pytest.mark.it("Stores the provided hostname value")
    async def test_hostname(self, mocker):
        mocker.patch.object(mqtt, "Client")
        client = MQTTClient(client_id=fake_device_id, hostname=fake_hostname, port=fake_port)
        assert client._hostname == fake_hostname

    @pytest.mark.it("Stores the provided port value")
    async def test_port(self, mocker):
        mocker.patch.object(mqtt, "Client")
        client = MQTTClient(client_id=fake_device_id, hostname=fake_hostname, port=fake_port)
        assert client._port == fake_port

    @pytest.mark.it("Stores the provided keepalive value (if provided)")
    async def test_keepalive(self, mocker):
        mocker.patch.object(mqtt, "Client")
        client = MQTTClient(
            client_id=fake_device_id,
            hostname=fake_hostname,
            port=fake_port,
            keep_alive=fake_keepalive,
        )
        assert client._keep_alive == fake_keepalive

    @pytest.mark.it("Stores the provided auto_reconnect value (if provided)")
    @pytest.mark.parametrize(
        "value", [pytest.param(True, id="Enabled"), pytest.param(False, id="Disabled")]
    )
    async def test_auto_reconnect(self, mocker, value):
        mocker.patch.object(mqtt, "Client")
        client = MQTTClient(
            client_id=fake_device_id, hostname=fake_hostname, port=fake_port, auto_reconnect=value
        )
        assert client._auto_reconnect == value

    @pytest.mark.it("Stores the provided reconnect_interval value (if provided)")
    async def test_reconnect_interval(self, mocker):
        mocker.patch.object(mqtt, "Client")
        my_interval = 5
        client = MQTTClient(
            client_id=fake_device_id,
            hostname=fake_hostname,
            port=fake_port,
            auto_reconnect=True,
            reconnect_interval=my_interval,
        )
        assert client._reconnect_interval == my_interval

@pytest.mark.it("Creates and stores an instance of the Paho MQTT Client") + async def test_instantiates_mqtt_client(self, mocker, transport): + mock_paho_constructor = mocker.patch.object(mqtt, "Client") + + client = MQTTClient( + client_id=fake_device_id, + hostname=fake_hostname, + port=fake_port, + transport=transport, + ) + + assert mock_paho_constructor.call_count == 1 + assert mock_paho_constructor.call_args == mocker.call( + client_id=fake_device_id, + clean_session=False, + protocol=mqtt.MQTTv311, + transport=transport, + reconnect_on_failure=False, + ) + assert client._mqtt_client is mock_paho_constructor.return_value + + @pytest.mark.it("Uses the provided SSLContext with the Paho MQTT Client") + async def test_ssl_context(self, mocker, transport): + mock_paho = mocker.patch.object(mqtt, "Client").return_value + mock_ssl_context = mocker.MagicMock() + + MQTTClient( + client_id=fake_device_id, + hostname=fake_hostname, + port=fake_port, + transport=transport, + ssl_context=mock_ssl_context, + ) + + assert mock_paho.tls_set_context.call_count == 1 + assert mock_paho.tls_set_context.call_args == mocker.call(context=mock_ssl_context) + + @pytest.mark.it( + "Uses a default SSLContext with the Paho MQTT Client if no SSLContext is provided" + ) + async def test_ssl_context_default(self, mocker, transport): + mock_paho = mocker.patch.object(mqtt, "Client").return_value + + MQTTClient( + client_id=fake_device_id, + hostname=fake_hostname, + port=fake_port, + transport=transport, + ) + + # NOTE: calling tls_set_context with None == using default context + assert mock_paho.tls_set_context.call_count == 1 + assert mock_paho.tls_set_context.call_args == mocker.call(context=None) + + @pytest.mark.it("Sets proxy using the provided ProxyOptions with the Paho MQTT Client") + async def test_proxy_options(self, mocker, proxy_options, transport): + mock_paho = mocker.patch.object(mqtt, "Client").return_value + + MQTTClient( + client_id=fake_device_id, + 
hostname=fake_hostname, + port=fake_port, + transport=transport, + proxy_options=proxy_options, + ) + + # Verify proxy has been set + assert mock_paho.proxy_set.call_count == 1 + assert mock_paho.proxy_set.call_args == mocker.call( + proxy_type=proxy_options.proxy_type_socks, + proxy_addr=proxy_options.proxy_address, + proxy_port=proxy_options.proxy_port, + proxy_username=proxy_options.proxy_username, + proxy_password=proxy_options.proxy_password, + ) + + @pytest.mark.it("Does not set any proxy on the Paho MQTT Client if no ProxyOptions is provided") + async def test_no_proxy_options(self, mocker, transport): + mock_paho = mocker.patch.object(mqtt, "Client").return_value + + MQTTClient( + client_id=fake_device_id, + hostname=fake_hostname, + port=fake_port, + transport=transport, + ) + + # Proxy was not set + assert mock_paho.proxy_set.call_count == 0 + + @pytest.mark.it( + "Sets the websockets path on the Paho MQTT Client using the provided value if using websockets" + ) + async def test_ws_path(self, mocker): + mock_paho = mocker.patch.object(mqtt, "Client").return_value + + MQTTClient( + client_id=fake_device_id, + hostname=fake_hostname, + port=fake_port, + transport="websockets", + websockets_path=fake_ws_path, + ) + + # Websockets path was set + assert mock_paho.ws_set_options.call_count == 1 + assert mock_paho.ws_set_options.call_args == mocker.call(path=fake_ws_path) + + @pytest.mark.it("Does not set the websocket path on the Paho MQTT Client if it is not provided") + async def test_no_ws_path(self, mocker): + mock_paho = mocker.patch.object(mqtt, "Client").return_value + + MQTTClient( + client_id=fake_device_id, hostname=fake_hostname, port=fake_port, transport="websockets" + ) + + # Websockets path was not set + assert mock_paho.ws_set_options.call_count == 0 + + @pytest.mark.it( + "Does not set the websocket path on the Paho MQTT Client if not using websockets" + ) + async def test_ws_path_no_ws(self, mocker): + mock_paho = mocker.patch.object(mqtt, 
"Client").return_value + + MQTTClient( + client_id=fake_device_id, + hostname=fake_hostname, + port=fake_port, + transport="tcp", + websockets_path=fake_ws_path, + ) + + # Websockets path was not set + assert mock_paho.ws_set_options.call_count == 0 + + @pytest.mark.it("Sets the initial connection state") + async def test_connection_state(self, mocker): + mocker.patch.object(mqtt, "Client") + client = MQTTClient(client_id=fake_device_id, hostname=fake_hostname, port=fake_port) + assert not client._connected + assert not client._desire_connection + + @pytest.mark.it("Sets the previous disconnection cause to None") + async def test_disconnection_cause(self, mocker): + mocker.patch.object(mqtt, "Client") + client = MQTTClient(client_id=fake_device_id, hostname=fake_hostname, port=fake_port) + assert client._disconnection_cause is None + + @pytest.mark.it("Sets the network loop Future to None") + async def test_network_loop(self, mocker): + mocker.patch.object(mqtt, "Client") + client = MQTTClient(client_id=fake_device_id, hostname=fake_hostname, port=fake_port) + assert client._network_loop is None + + @pytest.mark.it("Sets the reconnect daemon Task to None") + async def test_reconnect_daemon(self, mocker): + mocker.patch.object(mqtt, "Client") + client = MQTTClient(client_id=fake_device_id, hostname=fake_hostname, port=fake_port) + assert client._reconnect_daemon is None + + @pytest.mark.it("Sets initial operation tracking structures") + async def test_pending_ops(self, mocker): + mocker.patch.object(mqtt, "Client") + client = MQTTClient(client_id=fake_device_id, hostname=fake_hostname, port=fake_port) + assert client._pending_subs == {} + assert client._pending_unsubs == {} + assert client._pending_pubs == {} + + @pytest.mark.it("Creates an incoming message queue") + async def test_incoming_messages_unfiltered(self, mocker): + mocker.patch.object(mqtt, "Client") + client = MQTTClient(client_id=fake_device_id, hostname=fake_hostname, port=fake_port) + assert 
isinstance(client._incoming_messages, asyncio.Queue) + assert client._incoming_messages.empty() + + @pytest.mark.it("Sets initial filtered message queue structures") + async def test_incoming_messages_filtered(self, mocker): + mocker.patch.object(mqtt, "Client") + client = MQTTClient(client_id=fake_device_id, hostname=fake_hostname, port=fake_port) + assert client._incoming_filtered_messages == {} + + # TODO: May need public conditions tests (assuming they stay public) + + +@pytest.mark.describe("MQTTClient - .set_credentials()") +class TestSetCredentials: + @pytest.mark.it("Sets a username only") + def test_username(self, client, mock_paho, mocker): + assert mock_paho.username_pw_set.call_count == 0 + + client.set_credentials(fake_username) + + assert mock_paho.username_pw_set.call_count == 1 + assert mock_paho.username_pw_set.call_args == mocker.call( + username=fake_username, password=None + ) + + @pytest.mark.it("Sets a username and password combination") + def test_username_password(self, client, mock_paho, mocker): + assert mock_paho.username_pw_set.call_count == 0 + + client.set_credentials(fake_username, fake_password) + + assert mock_paho.username_pw_set.call_count == 1 + assert mock_paho.username_pw_set.call_args == mocker.call( + username=fake_username, password=fake_password + ) + + +@pytest.mark.describe("MQTTClient - .is_connected()") +class TestIsConnected: + @pytest.mark.it("Returns a boolean indicating the connection status") + @pytest.mark.parametrize( + "state, expected_value", + [ + pytest.param("Connected", True, id="Connected"), + pytest.param("Disconnected", False, id="Disconnected"), + pytest.param("Fresh", False, id="Fresh"), + pytest.param("Connection Dropped", False, id="Connection Dropped"), + ], + ) + async def test_returns_value(self, client, state, expected_value): + if state == "Connected": + client_set_connected(client) + elif state == "Disconnected": + client_set_disconnected(client) + elif state == "Fresh": + 
client_set_fresh(client) + elif state == "Connection Dropped": + client_set_connection_dropped + assert client.is_connected() == expected_value + + +@pytest.mark.describe("MQTTClient - .previous_disconnection_cause()") +class TestPreviousDisconnectionCause: + @pytest.mark.it("Returns the exception that caused the previous disconnection (if any)") + @pytest.mark.parametrize( + "state, expected_exc_type", + [ + pytest.param("Connected", type(None), id="Connected"), + pytest.param("Disconnected", type(None), id="Disconnected"), + pytest.param("Fresh", type(None), id="Fresh"), + pytest.param("Connection Dropped", MQTTConnectionDroppedError, id="Connection Dropped"), + ], + ) + async def test_returns_value(self, client, state, expected_exc_type): + if state == "Connected": + client_set_connected(client) + elif state == "Disconnected": + client_set_disconnected(client) + elif state == "Fresh": + client_set_fresh(client) + elif state == "Connection Dropped": + client_set_connection_dropped(client) + assert isinstance(client.previous_disconnection_cause(), expected_exc_type) + + +@pytest.mark.describe("MQTTClient - .add_incoming_message_filter()") +class TestAddIncomingMessageFilter: + @pytest.mark.it("Adds a new incoming message queue for the given topic") + def test_adds_queue(self, client): + assert len(client._incoming_filtered_messages) == 0 + + client.add_incoming_message_filter(fake_topic) + + assert len(client._incoming_filtered_messages) == 1 + assert isinstance(client._incoming_filtered_messages[fake_topic], asyncio.Queue) + assert client._incoming_filtered_messages[fake_topic].empty() + + @pytest.mark.it("Adds a callback for the given topic to the Paho MQTT Client") + def test_adds_callback(self, mocker, client, mock_paho): + assert mock_paho.message_callback_add.call_count == 0 + + client.add_incoming_message_filter(fake_topic) + + assert mock_paho.message_callback_add.call_count == 1 + assert mock_paho.message_callback_add.call_args == mocker.call(fake_topic, 
mocker.ANY) + + @pytest.mark.it( + "Raises a ValueError and does not add an incoming message queue or add a callback to the Paho MQTT Client if the filter already exists" + ) + def test_filter_exists(self, client, mock_paho): + client.add_incoming_message_filter(fake_topic) + assert fake_topic in client._incoming_filtered_messages + assert len(client._incoming_filtered_messages) == 1 + existing_queue = client._incoming_filtered_messages[fake_topic] + assert existing_queue.empty() + assert mock_paho.message_callback_add.call_count == 1 + + # Try and add the same topic filter again + with pytest.raises(ValueError): + client.add_incoming_message_filter(fake_topic) + + # No additional filter was added, nor were changes made to the existing one + assert fake_topic in client._incoming_filtered_messages + assert len(client._incoming_filtered_messages) == 1 + assert client._incoming_filtered_messages[fake_topic] == existing_queue + assert existing_queue.empty() + assert mock_paho.message_callback_add.call_count == 1 + + # NOTE: To see this filter in action, see the message receive tests + + +@pytest.mark.describe("MQTTClient - .remove_incoming_message_filter()") +class TestRemoveIncomingMessageFilter: + @pytest.mark.it("Removes the callback for the given topic from the Paho MQTT Client") + def test_removes_callback(self, mocker, client, mock_paho): + # Add a filter + client.add_incoming_message_filter(fake_topic) + assert mock_paho.message_callback_remove.call_count == 0 + + # Remove + client.remove_incoming_message_filter(fake_topic) + + # Callback was removed + assert mock_paho.message_callback_remove.call_count == 1 + assert mock_paho.message_callback_remove.call_args == mocker.call(fake_topic) + + @pytest.mark.it("Removes the incoming message queue for the given topic") + def test_removes_queue(self, client): + # Add a filter + client.add_incoming_message_filter(fake_topic) + assert fake_topic in client._incoming_filtered_messages + + # Remove + 
client.remove_incoming_message_filter(fake_topic) + + # Filter queue was removed + assert fake_topic not in client._incoming_filtered_messages + + @pytest.mark.it( + "Raises ValueError and does not remove any incoming message queues or remove any callbacks from the Paho MQTT Client if the filter does not exist" + ) + async def test_filter_does_not_exist(self, mocker, client, mock_paho): + # Add a different filter + client.add_incoming_message_filter(fake_topic) + assert fake_topic in client._incoming_filtered_messages + assert len(client._incoming_filtered_messages) == 1 + existing_queue = client._incoming_filtered_messages[fake_topic] + fake_item = mocker.MagicMock() + await existing_queue.put(fake_item) + assert existing_queue.qsize() == 1 + assert mock_paho.message_callback_remove.call_count == 0 + + # Remove a topic that has not yet been added + even_faker_topic = "even/faker/topic" + assert even_faker_topic != fake_topic + with pytest.raises(ValueError): + client.remove_incoming_message_filter(even_faker_topic) + + # No filter was removed or modified + assert fake_topic in client._incoming_filtered_messages + assert len(client._incoming_filtered_messages) == 1 + existing_queue = client._incoming_filtered_messages[fake_topic] + assert existing_queue.qsize() == 1 + item = await existing_queue.get() + assert item is fake_item + assert mock_paho.message_callback_remove.call_count == 0 + + +@pytest.mark.describe("MQTTClient - .get_incoming_message_generator()") +class TestGetIncomingMessageGenerator: + @pytest.mark.it( + "Returns a generator that yields items from the default incoming message queue if no filter topic is provided" + ) + async def test_default_generator(self, client): + # Get generator + incoming_messages = client.get_incoming_message_generator() + + # Add items to queue + item1 = mqtt.MQTTMessage(mid=1) + item2 = mqtt.MQTTMessage(mid=2) + item3 = mqtt.MQTTMessage(mid=3) + await client._incoming_messages.put(item1) + await 
client._incoming_messages.put(item2) + await client._incoming_messages.put(item3) + + # Use generator + result = await incoming_messages.__anext__() + assert result is item1 + result = await incoming_messages.__anext__() + assert result is item2 + result = await incoming_messages.__anext__() + assert result is item3 + + @pytest.mark.it( + "Returns a generator that yields items from a filtered incoming message queue if a filter topic is provided" + ) + async def test_filtered_generator(self, client): + # Add a filter, and get the generator + client.add_incoming_message_filter(fake_topic) + incoming_messages = client.get_incoming_message_generator(fake_topic) + + # Add items to queue + item1 = mqtt.MQTTMessage(mid=1) + item2 = mqtt.MQTTMessage(mid=2) + item3 = mqtt.MQTTMessage(mid=3) + await client._incoming_filtered_messages[fake_topic].put(item1) + await client._incoming_filtered_messages[fake_topic].put(item2) + await client._incoming_filtered_messages[fake_topic].put(item3) + + # Use generator + result = await incoming_messages.__anext__() + assert result is item1 + result = await incoming_messages.__anext__() + assert result is item2 + result = await incoming_messages.__anext__() + assert result is item3 + + @pytest.mark.it("Raises a ValueError if a filter has not been added for the given filter topic") + async def test_no_filter_added(self, client): + assert fake_topic not in client._incoming_filtered_messages + + with pytest.raises(ValueError): + client.get_incoming_message_generator(fake_topic) + + +# NOTE: Because clients in Disconnected, Connection Dropped, and Fresh states have the same +# behaviors during a connect, define a parent class that can be subclassed so tests don't have +# to be written twice. 
+class ConnectWithClientNotConnectedTests: + @pytest.mark.it( + "Starts the reconnect daemon and stores its task if auto_reconnect is enabled and the daemon is not yet running" + ) + async def test_reconnect_daemon_enabled_not_running(self, client): + client._auto_reconnect = True + assert client._reconnect_daemon is None + + await client.connect() + + assert isinstance(client._reconnect_daemon, asyncio.Task) + assert not client._reconnect_daemon.done() + + @pytest.mark.it("Does not start the reconnect daemon if auto_reconnect is disabled") + async def test_reconnect_daemon_disabled(self, client): + assert client._auto_reconnect is False + assert client._reconnect_daemon is None + + await client.connect() + + assert client._reconnect_daemon is None + + @pytest.mark.it("Does not start the reconnect daemon if it is already running") + async def test_reconnect_daemon_running(self, mocker, client): + client._auto_reconnect = True + mock_task = mocker.MagicMock() + client._reconnect_daemon = mock_task + + await client.connect() + + assert client._reconnect_daemon is mock_task + + @pytest.mark.it("Invokes an MQTT connect via Paho using stored values") + async def test_paho_invocation(self, mocker, client, mock_paho): + assert mock_paho.connect.call_count == 0 + + await client.connect() + + assert mock_paho.connect.call_count == 1 + assert mock_paho.connect.call_args == mocker.call( + host=client._hostname, port=client._port, keepalive=client._keep_alive + ) + + @pytest.mark.it( + "Raises a MQTTConnectionFailedError (non-fatal) if an exception is raised while invoking Paho's connect" + ) + async def test_fail_paho_invocation(self, client, mock_paho, arbitrary_exception): + mock_paho.connect.side_effect = arbitrary_exception + + with pytest.raises(MQTTConnectionFailedError) as e_info: + await client.connect() + assert e_info.value.__cause__ is arbitrary_exception + assert e_info.value.rc is None + assert not e_info.value.fatal + + # NOTE: This should be an invalid scenario 
as connect should not be able to return a failed return code + @pytest.mark.it( + "Raises a MQTTConnectionFailedError (non-fatal) if invoking Paho's connect returns a failed return code" + ) + @pytest.mark.parametrize("failing_rc", connect_failed_rc_params) + async def test_fail_status(self, client, mock_paho, failing_rc): + mock_paho._connect_rc = failing_rc + + with pytest.raises(MQTTConnectionFailedError) as e_info: + await client.connect() + assert e_info.value.rc is None + assert not e_info.value.fatal + cause = e_info.value.__cause__ + assert isinstance(cause, MQTTError) + assert cause.rc == failing_rc + + @pytest.mark.it( + "Starts the Paho network loop if the connect invocation is successful and the network loop is not already running" + ) + async def test_network_loop_connect_success(self, client, mock_paho): + assert not client._network_loop_running() + assert mock_paho.loop_forever.call_count == 0 + + await client.connect() + # Due to the way test infrastructure triggers CONNACK, it's possible for .connect() to + # return before the mock network loop has started. This isn't a concern in real usage, + # since CONNACK cannot be received until the network loop is running. 
+ await asyncio.sleep(0.1) + + assert mock_paho.loop_forever.call_count == 1 + assert isinstance(client._network_loop, asyncio.Future) + assert client._network_loop_running() + + @pytest.mark.it("Does not start the Paho network loop if the connect invocation raises") + async def test_network_loop_connect_fail_raise(self, client, mock_paho, arbitrary_exception): + assert not client._network_loop_running() + assert mock_paho.loop_forever.call_count == 0 + mock_paho.connect.side_effect = arbitrary_exception + + with pytest.raises(MQTTConnectionFailedError): + await client.connect() + + assert mock_paho.loop_forever.call_count == 0 + assert not client._network_loop_running() + + # NOTE: This should be an invalid scenario as connect should not be able to return a failed return code + @pytest.mark.it( + "Does not start the Paho network loop if the connect invocation returns a failed return code" + ) + @pytest.mark.parametrize("failing_rc", connect_failed_rc_params) + async def test_network_loop_connect_fail_status(self, client, mock_paho, failing_rc): + assert not client._network_loop_running() + assert mock_paho.loop_forever.call_count == 0 + mock_paho._connect_rc = failing_rc + + with pytest.raises(MQTTConnectionFailedError): + await client.connect() + + assert mock_paho.loop_forever.call_count == 0 + assert not client._network_loop_running() + + # NOTE: This is not common, but possible due to cancellation. See more in the cancellation tests. + # Admittedly, this can only really happen in the "Client Fresh" state, but we test it for all. 
+ @pytest.mark.it("Does not start the Paho network loop if it is already running") + async def test_network_loop_already_running(self, client, mock_paho): + event_loop = asyncio.get_running_loop() + client._network_loop = event_loop.run_in_executor(None, client._mqtt_client.loop_forever) + assert not client._network_loop.done() + assert mock_paho.loop_forever.call_count == 1 + + await client.connect() + + assert not client._network_loop.done() + assert mock_paho.loop_forever.call_count == 1 # Same as it was before + + @pytest.mark.it( + "Waits to return until Paho receives a success response if the connect invocation succeeded" + ) + async def test_waits_for_completion(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + + # Start a connect. It won't complete + connect_task = asyncio.create_task(client.connect()) + await asyncio.sleep(0.5) + assert not connect_task.done() + + # Trigger connect completion + mock_paho.trigger_on_connect(rc=mqtt.CONNACK_ACCEPTED) + await connect_task + + @pytest.mark.it( + "Raises a MQTTConnectionFailedError (non-fatal) if the connect attempt receives a failure response" + ) + @pytest.mark.parametrize("failing_rc", on_connect_failed_rc_params) + async def test_fail_response(self, client, mock_paho, failing_rc): + # Require manual completion + mock_paho._manual_mode = True + + # Attempt connect + connect_task = asyncio.create_task(client.connect()) + await asyncio.sleep(0.1) + + # Send failure CONNACK response + mock_paho.trigger_on_connect(rc=failing_rc) + # Any CONNACK failure also results in a ERR_CONN_REFUSED to on_disconnect + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_REFUSED) + with pytest.raises(MQTTConnectionFailedError) as e_info: + await connect_task + assert e_info.value.rc == failing_rc + assert not e_info.value.fatal + + @pytest.mark.it("Can handle responses received before or after Paho invocation returns") + @pytest.mark.parametrize("early_ack", early_ack_params) + async 
def test_early_ack(self, client, mock_paho, early_ack): + mock_paho._early_ack = early_ack + await client.connect() + # If this doesn't hang, the test passes + + @pytest.mark.it("Puts the client in a connected state if connection attempt is successful") + async def test_state_success(self, client): + assert not client.is_connected() + + await client.connect() + + assert client.is_connected() + + # NOTE: Technically, there can only really be a previous cause in the Connection Dropped case + # but we'll test it against all cases + @pytest.mark.it( + "Clears the previous disconnection cause (if any) if connection attempt is successful" + ) + @pytest.mark.parametrize( + "prev_cause", + [ + pytest.param(MQTTConnectionDroppedError(rc=7), id="Previous disconnection cause"), + pytest.param(None, id="No previous disconnection cause"), + ], + ) + async def test_disconnection_cause_clear_success(self, client, prev_cause): + client._disconnection_cause = prev_cause + assert client.previous_disconnection_cause() is prev_cause + + await client.connect() + + assert client._disconnection_cause is None + assert client.previous_disconnection_cause() is None + + @pytest.mark.it( + "Leaves the client in a disconnected state if an exception is raised while invoking Paho's connect" + ) + async def test_state_fail_raise(self, client, mock_paho, arbitrary_exception): + # Raise failure from connect + mock_paho.connect.side_effect = arbitrary_exception + assert not client.is_connected() + + with pytest.raises(MQTTConnectionFailedError): + await client.connect() + + assert not client.is_connected() + + # NOTE: This should be an invalid scenario as connect should not be able to return a failed return code + @pytest.mark.it( + "Leaves the client in a disconnected state if invoking Paho's connect returns a failed return code" + ) + @pytest.mark.parametrize("failing_rc", connect_failed_rc_params) + async def test_state_fail_status(self, client, mock_paho, failing_rc): + # Return a fail + 
mock_paho._connect_rc = failing_rc + assert not client.is_connected() + + with pytest.raises(MQTTConnectionFailedError): + await client.connect() + + assert not client.is_connected() + + @pytest.mark.it( + "Leaves the client in a disconnected state if the connect attempt receives a failure response" + ) + @pytest.mark.parametrize("failing_rc", on_connect_failed_rc_params) + async def test_state_fail_response(self, client, mock_paho, failing_rc): + # Require manual completion + mock_paho._manual_mode = True + assert not client.is_connected() + + # Attempt connect + connect_task = asyncio.create_task(client.connect()) + await asyncio.sleep(0.1) + # Send failure CONNACK response + mock_paho.trigger_on_connect(rc=failing_rc) + # Any CONNACK failure also results in an ERR_CONN_REFUSED to on_disconnect + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_REFUSED) + + with pytest.raises(MQTTConnectionFailedError): + await connect_task + + assert not client.is_connected() + + @pytest.mark.it( + "Leaves the reconnect daemon running if an exception is raised while invoking Paho's connect" + ) + async def test_reconnect_daemon_fail_raise(self, client, mock_paho, arbitrary_exception): + client._auto_reconnect = True + assert client._reconnect_daemon is None + # Raise failure from connect + mock_paho.connect.side_effect = arbitrary_exception + + with pytest.raises(MQTTConnectionFailedError): + await client.connect() + + assert isinstance(client._reconnect_daemon, asyncio.Task) + assert not client._reconnect_daemon.done() + + # NOTE: This should be an invalid scenario as connect should not be able to return a failed return code + @pytest.mark.it( + "Leaves the reconnect daemon running if invoking Paho's connect returns a failed return code" + ) + @pytest.mark.parametrize("failing_rc", connect_failed_rc_params) + async def test_reconnect_daemon_fail_status(self, client, mock_paho, failing_rc): + # Return a fail + mock_paho._connect_rc = failing_rc + client._auto_reconnect = 
True + assert client._reconnect_daemon is None + + with pytest.raises(MQTTConnectionFailedError): + await client.connect() + + assert isinstance(client._reconnect_daemon, asyncio.Task) + assert not client._reconnect_daemon.done() + + @pytest.mark.it( + "Leaves the reconnect daemon running if the connect attempt receives a failure response" + ) + @pytest.mark.parametrize("failing_rc", on_connect_failed_rc_params) + async def test_reconnect_daemon_fail_response(self, client, mock_paho, failing_rc): + # Require manual completion + mock_paho._manual_mode = True + client._auto_reconnect = True + assert client._reconnect_daemon is None + + # Attempt connect + connect_task = asyncio.create_task(client.connect()) + await asyncio.sleep(0.1) + # Send failure CONNACK response + mock_paho.trigger_on_connect(rc=failing_rc) + # Any CONNACK failure also results in an ERR_CONN_REFUSED to on_disconnect + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_REFUSED) + + with pytest.raises(MQTTConnectionFailedError): + await connect_task + + assert isinstance(client._reconnect_daemon, asyncio.Task) + assert not client._reconnect_daemon.done() + # Some test cases will need some help cleaning up + # TODO: is there a cleaner way to make sure this happens smoothly? 
+ # TODO: the issue is I think that connect is getting called by the task before it can get cleaned + client._reconnect_daemon.cancel() + + @pytest.mark.it( + "Clears the completed network loop Future if the connect attempt receives a failure response" + ) + @pytest.mark.parametrize("failing_rc", on_connect_failed_rc_params) + async def test_network_loop_fail_response(self, client, mock_paho, failing_rc): + # Require manual completion + mock_paho._manual_mode = True + + # NOTE: network loop may or may not already be running depending on state + + # Attempt connect + connect_task = asyncio.create_task(client.connect()) + await asyncio.sleep(0.1) + + # Network Loop is running + network_loop_future = client._network_loop + assert isinstance(network_loop_future, asyncio.Future) + assert not network_loop_future.done() + + # Send failure CONNACK response + mock_paho.trigger_on_connect(rc=failing_rc) + # Any CONNACK failure also results in an ERR_CONN_REFUSED to on_disconnect + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_REFUSED) + + with pytest.raises(MQTTConnectionFailedError): + await connect_task + + # Network Loop future completed, and was cleared + assert network_loop_future.done() + assert client._network_loop is None + + @pytest.mark.it( + "Raises CancelledError if cancelled while waiting for the Paho invocation to return" + ) + async def test_cancel_waiting_paho_invocation(self, client, mock_paho): + # Create a fake connect implementation that doesn't return right away + finish_connect = threading.Event() + waiting_on_paho = True + + def fake_connect(*args, **kwargs): + nonlocal waiting_on_paho + waiting_on_paho = True + finish_connect.wait() + waiting_on_paho = False + # mock_paho.trigger_on_connect(rc=mqtt.CONNACK_ACCEPTED) + return mqtt.MQTT_ERR_SUCCESS + + mock_paho.connect.side_effect = fake_connect + + # Start a connect task that will hang on Paho invocation + connect_task = asyncio.create_task(client.connect()) + await asyncio.sleep(0.1) + assert 
not connect_task.done() + # Paho invocation has not returned + assert waiting_on_paho + + # Cancel task + connect_task.cancel() + with pytest.raises(asyncio.CancelledError): + await connect_task + + # Allow the fake implementation to finish + finish_connect.set() + + @pytest.mark.it( + "Stops a reconnect daemon that was started on this current connect when cancelled while waiting for the Paho invocation to return" + ) + async def test_cancel_reconnect_daemon_current_connect_waiting_invoke(self, client, mock_paho): + # Create a fake connect implementation that doesn't return right away + finish_connect = threading.Event() + waiting_on_paho = True + + def fake_connect(*args, **kwargs): + nonlocal waiting_on_paho + waiting_on_paho = True + finish_connect.wait() + waiting_on_paho = False + return mqtt.MQTT_ERR_SUCCESS + + mock_paho.connect.side_effect = fake_connect + + # No reconnect daemon has started + client._auto_reconnect = True + assert client._reconnect_daemon is None + + # Start a connect task that will hang on Paho invocation + connect_task = asyncio.create_task(client.connect()) + await asyncio.sleep(0.1) + assert not connect_task.done() + # Paho invocation has not returned + assert waiting_on_paho + + # Reconnect daemon has been started + assert client._reconnect_daemon is not None + daemon_task = client._reconnect_daemon + assert not daemon_task.done() + + # Cancel task + connect_task.cancel() + with pytest.raises(asyncio.CancelledError): + await connect_task + + # Daemon task was completed and removed + assert client._reconnect_daemon is None + assert daemon_task.done() + + # Allow the fake implementation to finish + finish_connect.set() + + @pytest.mark.it( + "Does not stop a reconnect daemon that was started on a previous connect when cancelled while waiting for the Paho invocation to return" + ) + async def test_cancel_reconnect_daemon_previous_connect_waiting_invoke( + self, mocker, client, mock_paho + ): + # Create a fake connect implementation that 
doesn't return right away + finish_connect = threading.Event() + waiting_on_paho = True + + def fake_connect(*args, **kwargs): + nonlocal waiting_on_paho + waiting_on_paho = True + finish_connect.wait() + waiting_on_paho = False + return mqtt.MQTT_ERR_SUCCESS + + mock_paho.connect.side_effect = fake_connect + + # Reconnect daemon is already running + client._auto_reconnect = True + daemon_task = mocker.MagicMock() + client._reconnect_daemon = daemon_task + + # Start a connect task that will hang on Paho invocation + connect_task = asyncio.create_task(client.connect()) + await asyncio.sleep(0.1) + assert not connect_task.done() + # Paho invocation has not returned + assert waiting_on_paho + + # Reconnect daemon has not been altered + assert client._reconnect_daemon is daemon_task + assert daemon_task.cancel.call_count == 0 + + # Cancel task + connect_task.cancel() + with pytest.raises(asyncio.CancelledError): + await connect_task + + # Daemon task was unaffected + assert client._reconnect_daemon is daemon_task + assert daemon_task.cancel.call_count == 0 + + # Allow the fake implementation to finish + finish_connect.set() + + # NOTE: This test differs from the ones seen in Pub/Sub/Unsub because pending operations + # in a connect don't indicate the same thing they with the others. 
Instead we hack the mock + # some more to prove the expected behavior + @pytest.mark.it("Raises CancelledError if cancelled while waiting for a response") + async def test_cancel_waiting_response(self, client, mock_paho): + paho_invoke_done = False + + def fake_connect(*args, **kwargs): + nonlocal paho_invoke_done + paho_invoke_done = True + return mqtt.MQTT_ERR_SUCCESS + + mock_paho.connect.side_effect = fake_connect + + # Start a connect task + connect_task = asyncio.create_task(client.connect()) + await asyncio.sleep(0.1) + # We are now waiting for a response + assert not connect_task.done() + assert paho_invoke_done + + # Cancel the connect task + connect_task.cancel() + with pytest.raises(asyncio.CancelledError): + await connect_task + + @pytest.mark.it( + "Stops a reconnect daemon that was started on this current connect when cancelled while waiting for a response" + ) + async def test_cancel_reconnect_daemon_current_connect_waiting_response( + self, client, mock_paho + ): + paho_invoke_done = False + + def fake_connect(*args, **kwargs): + nonlocal paho_invoke_done + paho_invoke_done = True + return mqtt.MQTT_ERR_SUCCESS + + mock_paho.connect.side_effect = fake_connect + + # No reconnect daemon has started + client._auto_reconnect = True + assert client._reconnect_daemon is None + + # Start a connect task + connect_task = asyncio.create_task(client.connect()) + await asyncio.sleep(0.1) + # We are now waiting for a response + assert not connect_task.done() + assert paho_invoke_done + + # Reconnect daemon has been started + assert client._reconnect_daemon is not None + daemon_task = client._reconnect_daemon + assert not daemon_task.done() + + # Cancel the connect task + connect_task.cancel() + with pytest.raises(asyncio.CancelledError): + await connect_task + + # Daemon task was completed and removed + assert client._reconnect_daemon is None + assert daemon_task.done() + + @pytest.mark.it( + "Does not stop a reconnect daemon that was started on a previous 
connect when cancelled while waiting for a response" + ) + async def test_cancel_reconnect_daemon_previous_connect_waiting_response( + self, mocker, client, mock_paho + ): + paho_invoke_done = False + + def fake_connect(*args, **kwargs): + nonlocal paho_invoke_done + paho_invoke_done = True + return mqtt.MQTT_ERR_SUCCESS + + mock_paho.connect.side_effect = fake_connect + + # Reconnect daemon is already running + client._auto_reconnect = True + daemon_task = mocker.MagicMock() + client._reconnect_daemon = daemon_task + + # Start a connect task + connect_task = asyncio.create_task(client.connect()) + await asyncio.sleep(0.1) + # We are now waiting for a response + assert not connect_task.done() + assert paho_invoke_done + + # Reconnect daemon has not been altered + assert client._reconnect_daemon is daemon_task + assert daemon_task.cancel.call_count == 0 + + # Cancel the connect task + connect_task.cancel() + with pytest.raises(asyncio.CancelledError): + await connect_task + + # Daemon task was unaffected + assert client._reconnect_daemon is daemon_task + assert daemon_task.cancel.call_count == 0 + + +@pytest.mark.describe("MQTTClient - .connect() -- Client Fresh") +class TestConnectWithClientFresh(ConnectWithClientNotConnectedTests): + @pytest.fixture + async def client(self, fresh_client): + return fresh_client + + +@pytest.mark.describe("MQTTClient - .connect() -- Client Disconnected") +class TestConnectWithClientDisconnected(ConnectWithClientNotConnectedTests): + @pytest.fixture + async def client(self, fresh_client): + client = fresh_client + client_set_disconnected(client) + return client + + +@pytest.mark.describe("MQTTClient - .connect() -- Client Connection Dropped") +class TestConnectWithClientConnectionDropped(ConnectWithClientNotConnectedTests): + @pytest.fixture + async def client(self, fresh_client): + client = fresh_client + client_set_connection_dropped(client) + return client + + +@pytest.mark.describe("MQTTClient - .connect() -- Client Already 
Connected") +class TestConnectWithClientConnected: + @pytest.fixture + async def client(self, fresh_client): + client = fresh_client + client_set_connected(client) + return client + + @pytest.mark.it("Does not invoke an MQTT connect via Paho") + async def test_paho_invocation(self, client, mock_paho): + assert mock_paho.connect.call_count == 0 + + await client.connect() + + assert mock_paho.connect.call_count == 0 + + @pytest.mark.it("Does not start the reconnect daemon") + async def test_reconnect_daemon(self, client): + client._auto_reconnect = True + assert client._reconnect_daemon is None + + await client.connect() + + assert client._reconnect_daemon is None + + @pytest.mark.it("Does not start the Paho network loop") + async def test_network_loop(self, client, mock_paho): + # loop is already running due to being connected + assert client._network_loop_running() + assert mock_paho.loop_forever.call_count == 1 + + await client.connect() + + assert client._network_loop_running() + assert mock_paho.loop_forever.call_count == 1 # unchanged + + @pytest.mark.it("Leaves the client in a connected state") + async def test_state(self, client): + assert client.is_connected() + + await client.connect() + + assert client.is_connected() + + @pytest.mark.it("Leaves the disconnection cause set to None") + async def test_disconnection_cause(self, client): + assert client.previous_disconnection_cause() is None + + await client.connect() + + assert client.previous_disconnection_cause() is None + + @pytest.mark.it("Does not wait for a response before returning") + async def test_return(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + # Attempt connect + connect_task = asyncio.create_task(client.connect()) + # No waiting for connect response trigger was required + await connect_task + + +# NOTE: Disconnect responses can be either a single or double invocation of Paho's .on_disconnect() +# handler. Both cases are covered. Why does this happen? 
I don't know. But it does, and the client +# is designed to handle it. +# NOTE: Paho's .disconnect() method will always return success (rc = MQTT_ERR_SUCCESS) when the +# client is connected. As such, we don't have to test rc != MQTT_ERR_SUCCESS here +# (it is covered in other test classes) +@pytest.mark.describe("MQTTClient - .disconnect() -- Client Connected") +class TestDisconnectWithClientConnected: + @pytest.fixture + async def client(self, fresh_client): + client = fresh_client + client_set_connected(client) + return client + + @pytest.mark.it("Invokes an MQTT disconnect via Paho") + async def test_paho_invocation(self, mocker, client, mock_paho): + assert mock_paho.disconnect.call_count == 0 + + await client.disconnect() + + assert mock_paho.disconnect.call_count == 1 + assert mock_paho.disconnect.call_args == mocker.call() + + @pytest.mark.it("Waits to return until Paho receives a response and the network loop exits") + @pytest.mark.parametrize( + "double_response", + [ + pytest.param(False, id="Single Disconnect Response"), + pytest.param(True, id="Double Disconnect Response"), + ], + ) + async def test_waits_for_completion(self, client, mock_paho, double_response): + # Require manual completion + mock_paho._manual_mode = True + + # Start a disconnect. 
It won't complete + disconnect_task = asyncio.create_task(client.disconnect()) + await asyncio.sleep(0.5) + assert not disconnect_task.done() + network_loop_future = client._network_loop + assert isinstance(network_loop_future, asyncio.Future) + assert not network_loop_future.done() + + # Trigger disconnect completion + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + if double_response: + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + await disconnect_task + assert network_loop_future.done() + + @pytest.mark.it("Can handle responses received before or after Paho invocation returns") + @pytest.mark.parametrize("early_ack", early_ack_params) + async def test_early_ack(self, client, mock_paho, early_ack): + mock_paho._early_ack = early_ack + await client.disconnect() + # If this doesn't hang, the test passes + + @pytest.mark.it("Cancels and removes the reconnect daemon task if it is running") + @pytest.mark.parametrize( + "double_response", + [ + pytest.param(False, id="Single Disconnect Response"), + pytest.param(True, id="Double Disconnect Response"), + ], + ) + async def test_reconnect_daemon(self, mocker, client, mock_paho, double_response): + # Require manual completion + mock_paho._manual_mode = True + # Set a fake daemon task + mock_task = mocker.MagicMock() + client._reconnect_daemon = mock_task + + # Start a disconnect. 
+ disconnect_task = asyncio.create_task(client.disconnect()) + await asyncio.sleep(0.1) + # Trigger disconnect completion + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + if double_response: + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + await disconnect_task + + # Daemon was cancelled + assert mock_task.cancel.call_count == 1 + assert client._reconnect_daemon is None + + @pytest.mark.it("Puts the client in a disconnected state") + @pytest.mark.parametrize( + "double_response", + [ + pytest.param(False, id="Single Disconnect Response"), + pytest.param(True, id="Double Disconnect Response"), + ], + ) + async def test_state(self, client, mock_paho, double_response): + # Require manual completion + mock_paho._manual_mode = True + assert client.is_connected() + + # Start a disconnect. + disconnect_task = asyncio.create_task(client.disconnect()) + await asyncio.sleep(0.1) + # Trigger disconnect completion + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + if double_response: + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + await disconnect_task + + assert not client.is_connected() + + @pytest.mark.it("Does not set a disconnection cause") + @pytest.mark.parametrize( + "double_response", + [ + pytest.param(False, id="Single Disconnect Response"), + pytest.param(True, id="Double Disconnect Response"), + ], + ) + async def test_disconnection_cause(self, client, mock_paho, double_response): + # Require manual completion + mock_paho._manual_mode = True + assert client.previous_disconnection_cause() is None + + # Start a disconnect. 
+ disconnect_task = asyncio.create_task(client.disconnect()) + await asyncio.sleep(0.1) + # Trigger disconnect completion + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + if double_response: + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + await disconnect_task + + assert client.previous_disconnection_cause() is None + + @pytest.mark.it("Cancels and removes all pending subscribes and unsubscribes") + @pytest.mark.parametrize( + "double_response", + [ + pytest.param(False, id="Single Disconnect Response"), + pytest.param(True, id="Double Disconnect Response"), + ], + ) + async def test_cancel_sub_unsub(self, mocker, client, mock_paho, double_response): + # Require manual completion + mock_paho._manual_mode = True + # Set mocked pending Futures + mock_subs = [mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()] + mock_unsubs = [mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()] + client._pending_subs[1] = mock_subs[0] + client._pending_subs[2] = mock_subs[1] + client._pending_subs[3] = mock_subs[2] + client._pending_unsubs[4] = mock_unsubs[0] + client._pending_unsubs[5] = mock_unsubs[1] + client._pending_unsubs[6] = mock_unsubs[2] + + # Start a disconnect. 
+ disconnect_task = asyncio.create_task(client.disconnect()) + await asyncio.sleep(0.1) + # Trigger disconnect completion + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + if double_response: + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + await disconnect_task + + # All were cancelled + for mock in mock_subs: + assert mock.cancel.call_count == 1 + for mock in mock_unsubs: + assert mock.cancel.call_count == 1 + # All were removed + assert len(client._pending_subs) == 0 + assert len(client._pending_unsubs) == 0 + + @pytest.mark.it("Does not cancel or remove any pending publishes") + @pytest.mark.parametrize( + "double_response", + [ + pytest.param(False, id="Single Disconnect Response"), + pytest.param(True, id="Double Disconnect Response"), + ], + ) + async def test_no_cancel_pub(self, mocker, client, mock_paho, double_response): + # Set mocked pending Futures + mock_pubs = [mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()] + client._pending_pubs[1] = mock_pubs[0] + client._pending_pubs[2] = mock_pubs[1] + client._pending_pubs[3] = mock_pubs[2] + + # Start a disconnect. 
+ disconnect_task = asyncio.create_task(client.disconnect()) + await asyncio.sleep(0.1) + # Trigger disconnect completion + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + if double_response: + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + await disconnect_task + + # None were cancelled + for mock in mock_pubs: + assert mock.cancel.call_count == 0 + # None were removed + assert len(client._pending_pubs) == 3 + + @pytest.mark.it("Clears the completed network loop Future") + @pytest.mark.parametrize( + "double_response", + [ + pytest.param(False, id="Single Disconnect Response"), + pytest.param(True, id="Double Disconnect Response"), + ], + ) + async def test_network_loop(self, client, mock_paho, double_response): + # Require manual completion + mock_paho._manual_mode = True + + assert isinstance(client._network_loop, asyncio.Future) + network_loop_future = client._network_loop + assert not network_loop_future.done() + + # Start a disconnect. + disconnect_task = asyncio.create_task(client.disconnect()) + await asyncio.sleep(0.1) + # Trigger disconnect completion + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + if double_response: + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + await disconnect_task + + assert network_loop_future.done() + assert client._network_loop is None + + @pytest.mark.it( + "Raises CancelledError if cancelled while waiting for the Paho invocation to return" + ) + async def test_cancel_waiting_paho_invocation(self, client, mock_paho): + # Create a fake disconnect implementation that doesn't return right away + finish_disconnect = threading.Event() + waiting_on_paho = True + + def fake_disconnect(*args, **kwargs): + nonlocal waiting_on_paho + waiting_on_paho = True + finish_disconnect.wait() + waiting_on_paho = False + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + return mqtt.MQTT_ERR_SUCCESS + + mock_paho.disconnect.side_effect = fake_disconnect + + # Start a disconnect task that 
will hang on Paho invocation + disconnect_task = asyncio.create_task(client.disconnect()) + await asyncio.sleep(0.1) + assert not disconnect_task.done() + # Paho invocation has not returned + assert waiting_on_paho + + # Cancel task + disconnect_task.cancel() + with pytest.raises(asyncio.CancelledError): + await disconnect_task + + # Allow the fake implementation to finish + finish_disconnect.set() + await asyncio.sleep(0.1) + + # NOTE: This test differs from the ones seen in Pub/Sub/Unsub because pending operations + # in a disconnect don't indicate the same thing they with the others. Instead we hack the mock + # some more to prove the expected behavior + @pytest.mark.it("Raises CancelledError if cancelled while waiting for a response") + async def test_cancel_waiting_response(self, client, mock_paho): + paho_invoke_done = False + + def fake_disconnect(*args, **kwargs): + nonlocal paho_invoke_done + paho_invoke_done = True + return mqtt.MQTT_ERR_SUCCESS + + mock_paho.disconnect.side_effect = fake_disconnect + + # Start a disconnect task + disconnect_task = asyncio.create_task(client.disconnect()) + await asyncio.sleep(0.1) + # We are now waiting for a response + assert not disconnect_task.done() + assert paho_invoke_done + + # Cancel the disconnect task + disconnect_task.cancel() + with pytest.raises(asyncio.CancelledError): + await disconnect_task + + # TODO: why is this needed + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + await asyncio.sleep(0.1) + + +@pytest.mark.describe("MQTTClient - .disconnect() -- Client Connection Dropped") +class TestDisconnectWithClientConnectionDrop: + @pytest.fixture + async def client(self, fresh_client): + client = fresh_client + client_set_connection_dropped(client) + return client + + @pytest.mark.it("Invokes an MQTT disconnect via Paho") + async def test_paho_invocation(self, mocker, client, mock_paho): + assert mock_paho.disconnect.call_count == 0 + + await client.disconnect() + + assert 
mock_paho.disconnect.call_count == 1 + assert mock_paho.disconnect.call_args == mocker.call() + + @pytest.mark.it("Cancels and removes the reconnect daemon task if it is running") + async def test_reconnect_daemon(self, mocker, client): + # Set a fake daemon task + mock_task = mocker.MagicMock() + client._reconnect_daemon = mock_task + + await client.disconnect() + + assert mock_task.cancel.call_count == 1 + assert client._reconnect_daemon is None + + # NOTE: This is an invalid scenario. Connection being dropped implies there are + # no pending subscribes or unsubscribes + @pytest.mark.it("Does not cancel or remove any pending subscribes or unsubscribes") + async def test_pending_sub_unsub(self, mocker, client): + # Set mocked pending Futures + mock_subs = [mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()] + mock_unsubs = [mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()] + client._pending_subs[1] = mock_subs[0] + client._pending_subs[2] = mock_subs[1] + client._pending_subs[3] = mock_subs[2] + client._pending_unsubs[4] = mock_unsubs[0] + client._pending_unsubs[5] = mock_unsubs[1] + client._pending_unsubs[6] = mock_unsubs[2] + + await client.disconnect() + + # None were cancelled + for mock in mock_subs: + assert mock.cancel.call_count == 0 + for mock in mock_unsubs: + assert mock.cancel.call_count == 0 + # None were removed + assert len(client._pending_subs) == 3 + assert len(client._pending_unsubs) == 3 + + # NOTE: Unlike the above, this is a valid scenario. Publishes survive a connection drop. 
+ @pytest.mark.it("Does not cancel or remove any pending publishes") + async def test_pending_pub(self, mocker, client): + # Set mocked pending Futures + mock_pubs = [mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()] + client._pending_pubs[1] = mock_pubs[0] + client._pending_pubs[2] = mock_pubs[1] + client._pending_pubs[3] = mock_pubs[2] + + await client.disconnect() + + # None were cancelled + for mock in mock_pubs: + assert mock.cancel.call_count == 0 + # None were removed + assert len(client._pending_pubs) == 3 + + @pytest.mark.it("Leaves the client in a disconnected state") + async def test_state(self, client): + assert not client.is_connected() + + await client.disconnect() + + assert not client.is_connected() + + @pytest.mark.it("Clears the existing disconnection cause") + async def test_disconnection_cause(self, client): + assert client.previous_disconnection_cause() is not None + + await client.disconnect() + + assert client.previous_disconnection_cause() is None + + @pytest.mark.it("Does not wait for a response before returning") + async def test_return(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + # Attempt disconnect + disconnect_task = asyncio.create_task(client.disconnect()) + await disconnect_task + + @pytest.mark.it("Clears the completed network loop Future") + async def test_network_loop(self, mocker, client, mock_paho): + assert isinstance(client._network_loop, asyncio.Future) + network_loop_future = client._network_loop + # Connection Drop means that the loop task is done, but not cleared + assert network_loop_future.done() + + await client.disconnect() + + assert network_loop_future.done() + # Now the task has been cleared + assert client._network_loop is None + + @pytest.mark.it( + "Raises CancelledError if cancelled while waiting for the Paho invocation to return" + ) + async def test_cancel_waiting_paho_invocation(self, client, mock_paho): + # Create a fake disconnect implementation that 
doesn't return right away + finish_disconnect = threading.Event() + waiting_on_paho = True + + def fake_disconnect(*args, **kwargs): + nonlocal waiting_on_paho + waiting_on_paho = True + finish_disconnect.wait() + waiting_on_paho = False + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + return mqtt.MQTT_ERR_SUCCESS + + mock_paho.disconnect.side_effect = fake_disconnect + + # Start a disconnect task that will hang on Paho invocation + disconnect_task = asyncio.create_task(client.disconnect()) + await asyncio.sleep(0.1) + assert not disconnect_task.done() + # Paho invocation has not returned + assert waiting_on_paho + + # Cancel task + disconnect_task.cancel() + with pytest.raises(asyncio.CancelledError): + await disconnect_task + + # Allow the fake implementation to finish + finish_disconnect.set() + await asyncio.sleep(0.1) + + +# NOTE: Because clients in Disconnected and Fresh states have the same behaviors during a connect, +# define a parent class that can be subclassed so tests don't have to be written twice. +class DisconnectWithClientFullyDisconnectedTests: + @pytest.fixture + async def client(self, fresh_client): + client = fresh_client + client_set_disconnected(client) + return client + + @pytest.mark.it("Does not invoke an MQTT disconnect via Paho") + async def test_paho_invocation(self, client, mock_paho): + assert mock_paho.disconnect.call_count == 0 + + await client.disconnect() + + assert mock_paho.disconnect.call_count == 0 + + # NOTE: This could happen due to a connect failure that starts the daemon, but leaves the + # client in a fully disconnected state. 
+ @pytest.mark.it("Cancels and removes the reconnect daemon task if it is running") + async def test_reconnect_daemon(self, mocker, client): + # Set a fake daemon task + mock_task = mocker.MagicMock() + client._reconnect_daemon = mock_task + + await client.disconnect() + + assert mock_task.cancel.call_count == 1 + assert client._reconnect_daemon is None + + # NOTE: This is an invalid scenario. Being disconnected implies there are + # no pending subscribes or unsubscribes + @pytest.mark.it("Does not cancel or remove any pending subscribes or unsubscribes") + async def test_pending_sub_unsub(self, mocker, client): + # Set mocked pending Futures + mock_subs = [mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()] + mock_unsubs = [mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()] + client._pending_subs[1] = mock_subs[0] + client._pending_subs[2] = mock_subs[1] + client._pending_subs[3] = mock_subs[2] + client._pending_unsubs[4] = mock_unsubs[0] + client._pending_unsubs[5] = mock_unsubs[1] + client._pending_unsubs[6] = mock_unsubs[2] + + await client.disconnect() + + # None were cancelled + for mock in mock_subs: + assert mock.cancel.call_count == 0 + for mock in mock_unsubs: + assert mock.cancel.call_count == 0 + # None were removed + assert len(client._pending_subs) == 3 + assert len(client._pending_unsubs) == 3 + + # NOTE: Unlike the above, this is a valid scenario. Publishes survive a disconnect. 
+ @pytest.mark.it("Does not cancel or remove any pending publishes") + async def test_pending_pub(self, mocker, client): + # Set mocked pending Futures + mock_pubs = [mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()] + client._pending_pubs[1] = mock_pubs[0] + client._pending_pubs[2] = mock_pubs[1] + client._pending_pubs[3] = mock_pubs[2] + + await client.disconnect() + + # None were cancelled + for mock in mock_pubs: + assert mock.cancel.call_count == 0 + # None were removed + assert len(client._pending_pubs) == 3 + + @pytest.mark.it("Leaves the client in a disconnected state") + async def test_state(self, client): + assert not client.is_connected() + + await client.disconnect() + + assert not client.is_connected() + + @pytest.mark.it("Leaves the disconnection cause set to None") + async def test_disconnection_cause(self, client): + assert client.previous_disconnection_cause() is None + + await client.disconnect() + + assert client.previous_disconnection_cause() is None + + @pytest.mark.it("Does not wait for a response before returning") + async def test_return(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + # Attempt disconnect + disconnect_task = asyncio.create_task(client.disconnect()) + # No waiting for disconnect response trigger was required + await disconnect_task + + @pytest.mark.it("Does not alter the network loop Future") + async def test_network_loop(self, client): + assert client._network_loop is None + + await client.disconnect() + + assert client._network_loop is None + + +@pytest.mark.describe("MQTTClient - .disconnect() -- Client Already Disconnected") +class TestDisconnectWithClientDisconnected(DisconnectWithClientFullyDisconnectedTests): + @pytest.fixture + async def client(self, fresh_client): + client = fresh_client + client_set_disconnected(client) + return client + + +@pytest.mark.describe("MQTTClient - .disconnect() -- Client Fresh") +class 
TestDisconnectWithClientFresh(DisconnectWithClientFullyDisconnectedTests): + @pytest.fixture + async def client(self, fresh_client): + return fresh_client + + +@pytest.mark.describe("MQTTClient - OCCURRENCE: Unexpected Disconnect") +class TestUnexpectedDisconnect: + @pytest.fixture + async def client(self, fresh_client): + client = fresh_client + client_set_connected(client) + return client + + @pytest.mark.it("Puts the client in a disconnected state") + async def test_state(self, client, mock_paho): + assert client.is_connected() + + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_LOST) + await asyncio.sleep(0.1) + + assert not client.is_connected() + + @pytest.mark.it( + "Creates an MQTTError from the failed return code and sets it as the disconnection cause" + ) + async def test_disconnection_cause(self, client, mock_paho): + assert client.previous_disconnection_cause() is None + + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_LOST) + await asyncio.sleep(0.1) + + cause = client.previous_disconnection_cause() + assert isinstance(cause, MQTTConnectionDroppedError) + assert cause.rc is mqtt.MQTT_ERR_CONN_LOST + + @pytest.mark.it("Does not alter the reconnect daemon") + async def test_reconnect_daemon(self, mocker, client, mock_paho): + client._auto_reconnect = True + mock_task = mocker.MagicMock() + client._reconnect_daemon = mock_task + + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_LOST) + await asyncio.sleep(0.1) + + assert mock_task.cancel.call_count == 0 + assert client._reconnect_daemon is mock_task + + @pytest.mark.it("Cancels and removes all pending subscribes and unsubscribes") + async def test_cancel_sub_unsub(self, mocker, client, mock_paho): + # Set mocked pending Futures + mock_subs = [mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()] + mock_unsubs = [mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()] + client._pending_subs[1] = mock_subs[0] + client._pending_subs[2] = mock_subs[1] + client._pending_subs[3] = 
mock_subs[2] + client._pending_unsubs[4] = mock_unsubs[0] + client._pending_unsubs[5] = mock_unsubs[1] + client._pending_unsubs[6] = mock_unsubs[2] + + # Disconnect + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_LOST) + await asyncio.sleep(0.1) + + # All were cancelled + for mock in mock_subs: + assert mock.cancel.call_count == 1 + for mock in mock_unsubs: + assert mock.cancel.call_count == 1 + # All were removed + assert len(client._pending_subs) == 0 + assert len(client._pending_unsubs) == 0 + + @pytest.mark.it("Does not cancel or remove any pending publishes") + async def test_no_cancel_pub(self, mocker, client, mock_paho): + # Set mocked pending Futures + mock_pubs = [mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()] + client._pending_pubs[1] = mock_pubs[0] + client._pending_pubs[2] = mock_pubs[1] + client._pending_pubs[3] = mock_pubs[2] + + # Disconnect + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_LOST) + await asyncio.sleep(0.1) + + # None were cancelled + for mock in mock_pubs: + assert mock.cancel.call_count == 0 + # None were removed + assert len(client._pending_pubs) == 3 + + @pytest.mark.it("Does not remove the network loop Future, even though it completes") + async def test_network_loop(self, client, mock_paho): + assert client._network_loop is not None + assert isinstance(client._network_loop, asyncio.Future) + assert not client._network_loop.done() + network_loop_future = client._network_loop + + # Disconnect + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_LOST) + await asyncio.sleep(0.1) + + assert client._network_loop is not None + assert client._network_loop.done() + assert client._network_loop is network_loop_future + + +@pytest.mark.describe("MQTTClient - Connection Lock") +class TestConnectionLock: + @pytest.mark.it("Waits for a pending connect task to finish before attempting a connect") + @pytest.mark.parametrize( + "pending_success", + [ + pytest.param(True, id="Pending connect succeeds"), + 
pytest.param(False, id="Pending connect fails"), + ], + ) + async def test_connect_pending_connect(self, client, mock_paho, pending_success): + # Require manual completion + mock_paho._manual_mode = True + assert mock_paho.connect.call_count == 0 + # Client is currently disconnected + assert not client.is_connected() + + # Attempt first connect + connect_task1 = asyncio.create_task(client.connect()) + await asyncio.sleep(0.1) + # Paho connect has been called but task1 is still pending + assert mock_paho.connect.call_count == 1 + assert not connect_task1.done() + # Start second attempt + connect_task2 = asyncio.create_task(client.connect()) + await asyncio.sleep(0.1) + # Paho connect has NOT been called an additional time and task2 is still pending + assert mock_paho.connect.call_count == 1 + assert not connect_task2.done() + + # Complete first connect + if pending_success: + mock_paho.trigger_on_connect(rc=mqtt.CONNACK_ACCEPTED) + else: + # Failure triggers both. Use Server Unavailable as an arbitrary reason for failure. 
+ mock_paho.trigger_on_connect(rc=mqtt.CONNACK_REFUSED_SERVER_UNAVAILABLE) + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_REFUSED) + await asyncio.sleep(0.1) + assert connect_task1.done() + # Need to retrieve the exception to suppress error logging + if not pending_success: + with pytest.raises(MQTTConnectionFailedError): + connect_task1.result() + + if pending_success: + # Second connect was completed without invoking connect on Paho because it is + # already connected + assert client.is_connected() + assert connect_task2.done() + assert mock_paho.connect.call_count == 1 + assert client.is_connected() + else: + # Second connect has invoked connect on Paho and is waiting for completion + assert not client.is_connected() + assert not connect_task2.done() + assert mock_paho.connect.call_count == 2 + # Complete the second connect successfully + mock_paho.trigger_on_connect(rc=mqtt.CONNACK_ACCEPTED) + await connect_task2 + assert client.is_connected() + + # NOTE: Disconnect can't fail + @pytest.mark.it("Waits for a pending disconnect task to finish before attempting a connect") + async def test_connect_pending_disconnect(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + assert mock_paho.connect.call_count == 0 + assert mock_paho.disconnect.call_count == 0 + # Client must be connected for disconnect to pend + client_set_connected(client) + assert client.is_connected() + + # Attempt disconnect + disconnect_task = asyncio.create_task(client.disconnect()) + await asyncio.sleep(0.1) + # Paho disconnect has been called but task is still pending + assert mock_paho.disconnect.call_count == 1 + assert not disconnect_task.done() + # Attempt connect + connect_task = asyncio.create_task(client.connect()) + await asyncio.sleep(0.1) + # Paho connect has NOT been called yet and task is still pending + assert mock_paho.connect.call_count == 0 + assert not connect_task.done() + + # Complete disconnect + 
mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + await asyncio.sleep(0.1) + assert disconnect_task.done() + assert not client.is_connected() + + # Connect task has now invoked Paho connect and is waiting for completion + assert not connect_task.done() + assert mock_paho.connect.call_count == 1 + # Complete the connect + mock_paho.trigger_on_connect(rc=mqtt.CONNACK_ACCEPTED) + await connect_task + assert client.is_connected() + + @pytest.mark.it("Waits for a pending connect task to finish before attempting a disconnect") + @pytest.mark.parametrize( + "pending_success", + [ + pytest.param(True, id="Pending connect succeeds"), + pytest.param(False, id="Pending connect fails"), + ], + ) + async def test_disconnect_pending_connect(self, client, mock_paho, pending_success): + # Require manual completion + mock_paho._manual_mode = True + assert mock_paho.disconnect.call_count == 0 + assert mock_paho.disconnect.call_count == 0 + # Paho has to be disconnected for connect to pend + assert not client.is_connected() + + # Attempt connect + connect_task = asyncio.create_task(client.connect()) + await asyncio.sleep(0.1) + # Paho connect has been called but task is still pending + assert mock_paho.connect.call_count == 1 + assert not connect_task.done() + # Attempt disconnect + disconnect_task = asyncio.create_task(client.disconnect()) + await asyncio.sleep(0.1) + # Paho disconnect has NOT been called yet and task is still pending + assert mock_paho.disconnect.call_count == 0 + assert not disconnect_task.done() + + # Complete connect + if pending_success: + mock_paho.trigger_on_connect(rc=mqtt.CONNACK_ACCEPTED) + else: + # Failure triggers both. 
Use server unavailable as an arbitrary reason for failure + mock_paho.trigger_on_connect(rc=mqtt.CONNACK_REFUSED_SERVER_UNAVAILABLE) + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_REFUSED) + await asyncio.sleep(0.1) + assert connect_task.done() + # Need to retrieve the exception to suppress error logging + if not pending_success: + with pytest.raises(MQTTConnectionFailedError): + connect_task.result() + + if pending_success: + assert client.is_connected() + # Disconnect was invoked on Paho and is waiting for completion + assert not disconnect_task.done() + assert mock_paho.disconnect.call_count == 1 + # Complete the disconnect + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + await disconnect_task + assert not client.is_connected() + else: + assert not client.is_connected() + # Disconnect was completed without invoking connect on Paho because it is + # already disconnected + assert disconnect_task.done() + assert mock_paho.disconnect.call_count == 0 + assert not client.is_connected() + + # NOTE: Disconnect can't fail + @pytest.mark.it("Waits for a pending disconnect task to finish before attempting a disconnect") + async def test_disconnect_pending_disconnect(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + assert mock_paho.disconnect.call_count == 0 + # Client is currently connected + client_set_connected(client) + assert client.is_connected() + + # Attempt first disconnect + disconnect_task1 = asyncio.create_task(client.disconnect()) + await asyncio.sleep(0.1) + # Paho disconnect has been called but task 1 is still pending + assert mock_paho.disconnect.call_count == 1 + assert not disconnect_task1.done() + # Attempt second disconnect + disconnect_task2 = asyncio.create_task(client.disconnect()) + await asyncio.sleep(0.1) + # Paho disconnect has NOT been called an additional time and task2 is still pending + assert mock_paho.disconnect.call_count == 1 + assert not disconnect_task2.done() + + # Complete 
first disconnect + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + await disconnect_task1 + assert not client.is_connected() + + # Second disconnect was completed without invoking disconnect on Paho because it is + # already disconnected + await disconnect_task2 + assert mock_paho.disconnect.call_count == 1 + assert not client.is_connected() + + +@pytest.mark.describe("MQTTClient - Reconnect Daemon") +class TestReconnectDaemon: + @pytest.fixture + async def client(self, fresh_client): + client = fresh_client + client._auto_reconnect = True + client._reconnect_interval = 2 + # Successfully connect + await client.connect() + assert client.is_connected() + # Reconnect Daemon is running + assert isinstance(client._reconnect_daemon, asyncio.Task) + return client + + @pytest.mark.it("Attempts to connect immediately after an unexpected disconnection") + async def test_unexpected_drop(self, mocker, client, mock_paho): + # Set connect to fail. This is kind of arbitrary - we just need it to do something + client.connect = mocker.AsyncMock(side_effect=MQTTConnectionFailedError) + assert client.connect.call_count == 0 + + # Drop the connection + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_LOST) + await asyncio.sleep(0.1) + + # Connect was called by the daemon + assert client.connect.call_count == 1 + assert client.connect.call_args == mocker.call() + + @pytest.mark.it( + "Waits for the reconnect interval (in seconds) to try to connect again if the connect attempt fails non-fatally" + ) + async def test_reconnect_attempt_fails_nonfatal(self, mocker, client, mock_paho): + # Set connect to fail (nonfatal) + exc = MQTTConnectionFailedError(rc=mqtt.CONNACK_REFUSED_SERVER_UNAVAILABLE, fatal=False) + client.connect = mocker.AsyncMock(side_effect=exc) + assert client.connect.call_count == 0 + + # Drop the connection + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_LOST) + await asyncio.sleep(0.1) + + # Connect was called by the daemon + assert 
client.connect.call_count == 1 + # Wait half the interval + await asyncio.sleep(client._reconnect_interval / 2) + # Connect has not been called again + assert client.connect.call_count == 1 + # Wait the rest of the interval + await asyncio.sleep(client._reconnect_interval / 2) + # Connect was attempted again + assert client.connect.call_count == 2 + + @pytest.mark.it("Ends reconnect attempts if the connect attempt fails fatally") + async def test_reconnect_attempt_fails_fatal(self, mocker, client, mock_paho): + # Set connect to fail (fatal) + exc = MQTTConnectionFailedError(message="Some fatal exc", fatal=True) + client.connect = mocker.AsyncMock(side_effect=exc) + assert client.connect.call_count == 0 + + # Drop the connect + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_LOST) + await asyncio.sleep(0.1) + + # Connect was called by the daemon + assert client.connect.call_count == 1 + # Daemon has exited + assert client._reconnect_daemon.done() + + @pytest.mark.it( + "Does not try again until the next unexpected disconnection if the connect attempt succeeds" + ) + async def test_reconnect_attempt_succeeds(self, mocker, client, mock_paho): + # Set connect to succeed + def fake_connect(): + client_set_connected(client) + + client.connect = mocker.AsyncMock(side_effect=fake_connect) + assert client.connect.call_count == 0 + + # Drop the connection + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_LOST) + await asyncio.sleep(0.1) + + # Connect was called by the daemon + assert client.connect.call_count == 1 + # Wait for the interval + await asyncio.sleep(client._reconnect_interval) + # Connect was not attempted again + assert client.connect.call_count == 1 + + # Drop the connection again + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_LOST) + await asyncio.sleep(0.1) + + # Connect was attempted again + assert client.connect.call_count == 2 + + @pytest.mark.it("Does not attempt to connect after an expected disconnection") + async def 
test_disconnect(self, mocker, client): + # Set connect to fail. This is kind of arbitrary - we just need it to do something + client.connect = mocker.AsyncMock(side_effect=MQTTConnectionFailedError) + assert client.connect.call_count == 0 + + # Disconnect + await client.disconnect() + await asyncio.sleep(0.1) + + # Connect was not called by the daemon + assert client.connect.call_count == 0 + + +@pytest.mark.describe("MQTTClient - .subscribe()") +class TestSubscribe: + @pytest.mark.it("Invokes an MQTT subscribe via Paho") + async def test_paho_invocation(self, mocker, client, mock_paho): + assert mock_paho.subscribe.call_count == 0 + + await client.subscribe(fake_topic) + + assert mock_paho.subscribe.call_count == 1 + assert mock_paho.subscribe.call_args == mocker.call(topic=fake_topic, qos=1) + + @pytest.mark.it("Raises a MQTTError if invoking Paho's subscribe returns a failed return code") + @pytest.mark.parametrize("failing_rc", subscribe_failed_rc_params) + async def test_fail_status(self, client, mock_paho, failing_rc): + mock_paho._subscribe_rc = failing_rc + + with pytest.raises(MQTTError) as e_info: + await client.subscribe(fake_topic) + assert e_info.value.rc == failing_rc + + @pytest.mark.it("Allows any exceptions raised by invoking Paho's subscribe to propagate") + async def test_fail_paho_invocation_raises(self, client, mock_paho, arbitrary_exception): + mock_paho.subscribe.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)): + await client.subscribe(fake_topic) + + @pytest.mark.it( + "Waits to return until Paho receives a matching response if the subscribe invocation succeeded" + ) + async def test_matching_completion(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + + # Start a subscribe. 
It won't complete + subscribe_task = asyncio.create_task(client.subscribe(fake_topic)) + await asyncio.sleep(0.5) + assert not subscribe_task.done() + + # Trigger subscribe completion + mock_paho.trigger_on_subscribe(mock_paho._last_mid) + await subscribe_task + + @pytest.mark.it("Does not return if Paho receives a non-matching response") + async def test_nonmatching_completion(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + + # Start two subscribes. They won't complete + subscribe_task1 = asyncio.create_task(client.subscribe(fake_topic)) + await asyncio.sleep(0.1) + subscribe_task1_mid = mock_paho._last_mid + subscribe_task2 = asyncio.create_task(client.subscribe(fake_topic)) + await asyncio.sleep(0.1) + subscribe_task2_mid = mock_paho._last_mid + assert subscribe_task1_mid != subscribe_task2_mid + await asyncio.sleep(0.5) + assert not subscribe_task1.done() + assert not subscribe_task2.done() + + # Trigger subscribe completion for one of them + mock_paho.trigger_on_subscribe(subscribe_task2_mid) + # The corresponding task completes + await subscribe_task2 + # The other does not + assert not subscribe_task1.done() + + # Complete the other one + mock_paho.trigger_on_subscribe(subscribe_task1_mid) + await subscribe_task1 + + @pytest.mark.it("Can handle responses received before or after Paho invocation returns") + @pytest.mark.parametrize("early_ack", early_ack_params) + async def test_early_ack(self, client, mock_paho, early_ack): + mock_paho._early_ack = early_ack + await client.subscribe(fake_topic) + # If this doesn't hang, the test passes + + @pytest.mark.it( + "Retains pending subscribe tracking information only until receiving a response" + ) + async def test_pending(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + + # Start a subscribe. 
It won't complete + subscribe_task = asyncio.create_task(client.subscribe(fake_topic)) + await asyncio.sleep(0.1) + + # Pending subscribe is tracked + mid = mock_paho._last_mid + assert mid in client._pending_subs + + # Trigger subscribe completion + mock_paho.trigger_on_subscribe(mid) + await subscribe_task + + # Pending subscribe is no longer tracked + assert mid not in client._pending_subs + + @pytest.mark.it( + "Does not establish pending subscribe tracking information if invoking Paho's subscribe returns a failed return code" + ) + @pytest.mark.parametrize("failing_rc", subscribe_failed_rc_params) + async def test_pending_fail_status(self, client, mock_paho, failing_rc): + mock_paho._subscribe_rc = failing_rc + + with pytest.raises(MQTTError): + await client.subscribe(fake_topic) + + assert len(client._pending_subs) == 0 + + @pytest.mark.it( + "Does not establish pending subscribe tracking information if invoking Paho's subscribe raises an exception" + ) + async def test_pending_fail_paho_raise(self, client, mock_paho, arbitrary_exception): + mock_paho.subscribe.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)): + await client.subscribe(fake_topic) + + assert len(client._pending_subs) == 0 + + @pytest.mark.it( + "Raises CancelledError if cancelled while waiting for the Paho subscribe invocation to return" + ) + async def test_cancel_waiting_paho_invocation(self, client, mock_paho): + # Create a fake subscribe implementation that doesn't return right away + finish_subscribe = threading.Event() + waiting_on_paho = False + + def fake_subscribe(*args, **kwargs): + nonlocal waiting_on_paho + waiting_on_paho = True + finish_subscribe.wait() + waiting_on_paho = False + + mock_paho.subscribe.side_effect = fake_subscribe + assert len(client._pending_subs) == 0 + + # Start a subscribe task that will hang on Paho invocation + subscribe_task = asyncio.create_task(client.subscribe(fake_topic)) + await asyncio.sleep(0.1) + assert not 
subscribe_task.done() + # Paho invocation has not returned + assert waiting_on_paho + assert len(client._pending_subs) == 0 + + # Cancel task + subscribe_task.cancel() + with pytest.raises(asyncio.CancelledError): + await subscribe_task + + # Allow the fake implementation to finish + finish_subscribe.set() + + @pytest.mark.it("Raises CancelledError if cancelled while waiting for a response") + async def test_cancel_waiting_response(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + assert len(client._pending_subs) == 0 + + # Start an subscribe task and cancel it. + subscribe_task = asyncio.create_task(client.subscribe(fake_topic)) + await asyncio.sleep(0.1) + assert not subscribe_task.done() + # The sub pending means we received a mid from the invocation + # i.e. we are now waiting for a response + assert len(client._pending_subs) == 1 + + subscribe_task.cancel() + with pytest.raises(asyncio.CancelledError): + await subscribe_task + + # NOTE: There's no subscribe tracking information if cancelled while waiting for invocation + # as we don't have a MID yet. + @pytest.mark.it( + "Clears pending subscribe tracking information if cancelled while waiting for a response" + ) + async def test_pending_cancelled(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + + # Start a subscribe. 
It won't complete + subscribe_task = asyncio.create_task(client.subscribe(fake_topic)) + await asyncio.sleep(0.1) + + # Pending subscribe is tracked + mid = mock_paho._last_mid + assert mid in client._pending_subs + + # Cancel + subscribe_task.cancel() + with pytest.raises(asyncio.CancelledError): + await subscribe_task + + # Pending subscribe is no longer tracked + assert mid not in client._pending_subs + + @pytest.mark.it( + "Raises CancelledError if the pending subscribe is cancelled by a disconnect attempt" + ) + async def test_cancelled_by_disconnect(self, client, mock_paho): + client_set_connected(client) + # Require manual completion + mock_paho._manual_mode = True + + # Start a subscribe. It won't complete + subscribe_task = asyncio.create_task(client.subscribe(fake_topic)) + await asyncio.sleep(0.1) + + # Do a disconnect + disconnect_task = asyncio.create_task(client.disconnect()) + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + await asyncio.sleep(0.1) + + with pytest.raises(asyncio.CancelledError): + await subscribe_task + + await disconnect_task + + @pytest.mark.it( + "Raises CancelledError if the pending subscribe is cancelled by an unexpected disconnect" + ) + async def test_cancelled_by_unexpected_disconnect(self, client, mock_paho): + client_set_connected(client) + # Require manual completion + mock_paho._manual_mode = True + + # Start a subscribe. It won't complete + subscribe_task = asyncio.create_task(client.subscribe(fake_topic)) + await asyncio.sleep(0.1) + + # Trigger unexpected disconnect + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_LOST) + await asyncio.sleep(0.1) + + with pytest.raises(asyncio.CancelledError): + await subscribe_task + + @pytest.mark.it( + "Can handle receiving a response for a subscribe that was cancelled after it was in-flight" + ) + async def test_ack_after_cancel(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + + # Start a subscribe. 
It won't complete + subscribe_task = asyncio.create_task(client.subscribe(fake_topic)) + await asyncio.sleep(0.1) + + # Pending subscribe is tracked + mid = mock_paho._last_mid + assert mid in client._pending_subs + + # Cancel + subscribe_task.cancel() + with pytest.raises(asyncio.CancelledError): + await subscribe_task + + # Pending subscribe is no longer tracked + assert mid not in client._pending_subs + + # Trigger subscribe response after cancellation + mock_paho.trigger_on_subscribe(mid) + await asyncio.sleep(0.1) + + # No failure, no problem + + +@pytest.mark.describe("MQTTClient - .unsubscribe()") +class TestUnsubscribe: + @pytest.mark.it("Invokes an MQTT unsubscribe via Paho") + async def test_paho_invocation(self, mocker, client, mock_paho): + assert mock_paho.unsubscribe.call_count == 0 + + await client.unsubscribe(fake_topic) + + assert mock_paho.unsubscribe.call_count == 1 + assert mock_paho.unsubscribe.call_args == mocker.call(topic=fake_topic) + + @pytest.mark.it( + "Raises a MQTTError if invoking Paho's unsubscribe returns a failed return code" + ) + @pytest.mark.parametrize("failing_rc", unsubscribe_failed_rc_params) + async def test_fail_status(self, client, mock_paho, failing_rc): + mock_paho._unsubscribe_rc = failing_rc + + with pytest.raises(MQTTError) as e_info: + await client.unsubscribe(fake_topic) + assert e_info.value.rc == failing_rc + + @pytest.mark.it("Allows any exceptions raised by invoking Paho's unsubscribe to propagate") + async def test_fail_paho_invocation_raises(self, client, mock_paho, arbitrary_exception): + mock_paho.unsubscribe.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)): + await client.unsubscribe(fake_topic) + + @pytest.mark.it( + "Waits to return until Paho receives a matching response if the unsubscribe invocation succeeded" + ) + async def test_waits_for_completion(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + + # Start a unsubscribe. 
It won't complete + unsubscribe_task = asyncio.create_task(client.unsubscribe(fake_topic)) + await asyncio.sleep(0.5) + assert not unsubscribe_task.done() + + # Trigger unsubscribe completion + mock_paho.trigger_on_unsubscribe(mock_paho._last_mid) + await unsubscribe_task + + @pytest.mark.it("Does not return if Paho receives a non-matching response") + async def test_nonmatching_completion(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + + # Start two unsubscribes. They won't complete + unsubscribe_task1 = asyncio.create_task(client.unsubscribe(fake_topic)) + await asyncio.sleep(0.1) + unsubscribe_task1_mid = mock_paho._last_mid + unsubscribe_task2 = asyncio.create_task(client.unsubscribe(fake_topic)) + await asyncio.sleep(0.1) + unsubscribe_task2_mid = mock_paho._last_mid + assert unsubscribe_task1_mid != unsubscribe_task2_mid + await asyncio.sleep(0.5) + assert not unsubscribe_task1.done() + assert not unsubscribe_task2.done() + + # Trigger unsubscribe completion for one of them + mock_paho.trigger_on_unsubscribe(unsubscribe_task2_mid) + # The corresponding task completes + await unsubscribe_task2 + # The other does not + assert not unsubscribe_task1.done() + + # Complete the other one + mock_paho.trigger_on_unsubscribe(unsubscribe_task1_mid) + await unsubscribe_task1 + + @pytest.mark.it("Can handle responses received before or after Paho invocation returns") + @pytest.mark.parametrize("early_ack", early_ack_params) + async def test_early_ack(self, client, mock_paho, early_ack): + mock_paho._early_ack = early_ack + await client.unsubscribe(fake_topic) + # If this doesn't hang, the test passes + + @pytest.mark.it( + "Retains pending unsubscribe tracking information only until receiving a response" + ) + async def test_pending(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + + # Start a unsubscribe. 
It won't complete + unsubscribe_task = asyncio.create_task(client.unsubscribe(fake_topic)) + await asyncio.sleep(0.1) + + # Pending unsubscribe is tracked + mid = mock_paho._last_mid + assert mid in client._pending_unsubs + + # Trigger unsubscribe completion + mock_paho.trigger_on_unsubscribe(mid) + await unsubscribe_task + + # Pending unsubscribe is no longer tracked + assert mid not in client._pending_unsubs + + @pytest.mark.it( + "Does not establish pending unsubscribe tracking information if invoking Paho's unsubscribe returns a failed return code" + ) + @pytest.mark.parametrize("failing_rc", unsubscribe_failed_rc_params) + async def test_pending_fail_status(self, client, mock_paho, failing_rc): + mock_paho._unsubscribe_rc = failing_rc + + with pytest.raises(MQTTError): + await client.unsubscribe(fake_topic) + + assert len(client._pending_unsubs) == 0 + + @pytest.mark.it( + "Does not establish pending unsubscribe tracking information if invoking Paho's unsubscribe raises an exception" + ) + async def test_pending_fail_paho_raise(self, client, mock_paho, arbitrary_exception): + mock_paho.unsubscribe.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)): + await client.unsubscribe(fake_topic) + + assert len(client._pending_unsubs) == 0 + + @pytest.mark.it( + "Raises CancelledError if cancelled while waiting for the Paho unsubscribe invocation to return" + ) + async def test_cancel_waiting_paho_invocation(self, client, mock_paho): + # Create a fake unsubscribe implementation that doesn't return right away + finish_unsubscribe = threading.Event() + waiting_on_paho = False + + def fake_unsubscribe(*args, **kwargs): + nonlocal waiting_on_paho + waiting_on_paho = True + finish_unsubscribe.wait() + waiting_on_paho = False + + mock_paho.unsubscribe.side_effect = fake_unsubscribe + assert len(client._pending_unsubs) == 0 + + # Start a subscribe task that will hang on Paho invocation + unsubscribe_task = 
asyncio.create_task(client.unsubscribe(fake_topic)) + await asyncio.sleep(0.1) + assert not unsubscribe_task.done() + # Paho invocation has not returned + assert waiting_on_paho + assert len(client._pending_unsubs) == 0 + + # Cancel task + unsubscribe_task.cancel() + with pytest.raises(asyncio.CancelledError): + await unsubscribe_task + + # Allow the fake implementation to finish + finish_unsubscribe.set() + + @pytest.mark.it("Raises CancelledError if cancelled while waiting for a response") + async def test_cancel_waiting_response(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + assert len(client._pending_subs) == 0 + + # Start an unsubscribe task and cancel it. + unsubscribe_task = asyncio.create_task(client.unsubscribe(fake_topic)) + await asyncio.sleep(0.1) + assert not unsubscribe_task.done() + # The unsub pending means we received a mid from the invocation + # i.e. we are now waiting for a response + assert len(client._pending_unsubs) == 1 + + unsubscribe_task.cancel() + with pytest.raises(asyncio.CancelledError): + await unsubscribe_task + + # NOTE: There's no unsubscribe tracking information if cancelled while waiting for invocation + # as we don't have a MID yet. + @pytest.mark.it("Clears pending unsubscribe tracking information if cancelled") + async def test_pending_cancelled(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + + # Start a unsubscribe. 
It won't complete + unsubscribe_task = asyncio.create_task(client.unsubscribe(fake_topic)) + await asyncio.sleep(0.1) + + # Pending unsubscribe is tracked + mid = mock_paho._last_mid + assert mid in client._pending_unsubs + + # Cancel + unsubscribe_task.cancel() + with pytest.raises(asyncio.CancelledError): + await unsubscribe_task + + # Pending unsubscribe is no longer tracked + assert mid not in client._pending_unsubs + + @pytest.mark.it( + "Raises CancelledError if the pending unsubscribe is cancelled by a disconnect attempt" + ) + async def test_cancelled_by_disconnect(self, client, mock_paho): + client_set_connected(client) + # Require manual completion + mock_paho._manual_mode = True + + # Start a unsubscribe. It won't complete + unsubscribe_task = asyncio.create_task(client.unsubscribe(fake_topic)) + await asyncio.sleep(0.1) + + # Do a disconnect + disconnect_task = asyncio.create_task(client.disconnect()) + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_SUCCESS) + await asyncio.sleep(0.1) + + with pytest.raises(asyncio.CancelledError): + await unsubscribe_task + + await disconnect_task + + @pytest.mark.it( + "Raises CancelledError if the pending unsubscribe is cancelled by an unexpected disconnect" + ) + async def test_cancelled_by_unexpected_disconnect(self, client, mock_paho): + client_set_connected(client) + # Require manual completion + mock_paho._manual_mode = True + + # Start a unsubscribe. 
It won't complete + unsubscribe_task = asyncio.create_task(client.unsubscribe(fake_topic)) + await asyncio.sleep(0.1) + + # Trigger unexpected disconnect + mock_paho.trigger_on_disconnect(rc=mqtt.MQTT_ERR_CONN_LOST) + await asyncio.sleep(0.1) + + with pytest.raises(asyncio.CancelledError): + await unsubscribe_task + + @pytest.mark.it( + "Can handle receiving a response for an unsubscribe that was cancelled after it was in-flight" + ) + async def test_ack_after_cancel(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + + # Start a unsubscribe. It won't complete + unsubscribe_task = asyncio.create_task(client.unsubscribe(fake_topic)) + await asyncio.sleep(0.1) + + # Pending unsubscribe is tracked + mid = mock_paho._last_mid + assert mid in client._pending_unsubs + + # Cancel + unsubscribe_task.cancel() + with pytest.raises(asyncio.CancelledError): + await unsubscribe_task + + # Pending subscribe is no longer tracked + assert mid not in client._pending_unsubs + + # Trigger unsubscribe response after cancellation + mock_paho.trigger_on_unsubscribe(mid) + await asyncio.sleep(0.1) + + # No failure, no problem + + +@pytest.mark.describe("MQTTClient - .publish()") +class TestPublish: + @pytest.mark.it("Invokes an MQTT publish via Paho") + async def test_paho_invocation(self, mocker, client, mock_paho): + assert mock_paho.publish.call_count == 0 + + await client.publish(fake_topic, fake_payload) + + assert mock_paho.publish.call_count == 1 + assert mock_paho.publish.call_args == mocker.call( + topic=fake_topic, payload=fake_payload, qos=1 + ) + + # NOTE: MQTT_ERR_NO_CONN is not a failure for publish + @pytest.mark.it("Raises a MQTTError if invoking Paho's publish returns a failed return code") + @pytest.mark.parametrize("failing_rc", publish_failed_rc_params) + async def test_fail_status(self, client, mock_paho, failing_rc): + mock_paho._publish_rc = failing_rc + + with pytest.raises(MQTTError) as e_info: + await 
client.publish(fake_topic, fake_payload) + assert e_info.value.rc == failing_rc + + @pytest.mark.it("Allows any exceptions raised by invoking Paho's publish to propagate") + async def test_fail_paho_invocation_raises(self, client, mock_paho, arbitrary_exception): + mock_paho.publish.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)): + await client.publish(fake_topic, fake_payload) + + @pytest.mark.it( + "Waits to return until Paho receives a matching response if the publish invocation succeeded" + ) + async def test_matching_completion_success(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + mock_paho._publish_rc = mqtt.MQTT_ERR_SUCCESS + + # Start a publish. It won't complete + publish_task = asyncio.create_task(client.publish(fake_topic, fake_payload)) + await asyncio.sleep(0.5) + assert not publish_task.done() + + # Trigger publish completion + mock_paho.trigger_on_publish(mock_paho._last_mid) + await publish_task + + @pytest.mark.it( + "Waits to return until Paho receives a matching response (after connect established) if the publish invocation returned 'Not Connected'" + ) + async def test_matching_completion_no_conn(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + mock_paho._publish_rc = mqtt.MQTT_ERR_NO_CONN + + # Start a publish. It won't complete + publish_task = asyncio.create_task(client.publish(fake_topic, fake_payload)) + await asyncio.sleep(0.5) + assert not publish_task.done() + + # NOTE: Yeah, the test refers to after the connect is established, but there's no need + # to bring connection state into play here. Point is, after becoming connected, Paho will + # automatically re-publish, and when a response is received it will trigger completion. 
+ + # Trigger publish completion + mock_paho.trigger_on_publish(mock_paho._last_mid) + await publish_task + + @pytest.mark.it("Does not return if Paho receives a non-matching response") + async def test_nonmatching_completion(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + + # Start two publishes. They won't complete + publish_task1 = asyncio.create_task(client.publish(fake_topic, fake_payload)) + await asyncio.sleep(0.1) + publish_task1_mid = mock_paho._last_mid + publish_task2 = asyncio.create_task(client.publish(fake_topic, fake_payload)) + await asyncio.sleep(0.1) + publish_task2_mid = mock_paho._last_mid + assert publish_task1_mid != publish_task2_mid + await asyncio.sleep(0.5) + assert not publish_task1.done() + assert not publish_task2.done() + + # Trigger publish completion for one of them + mock_paho.trigger_on_publish(publish_task2_mid) + # The corresponding task completes + await publish_task2 + # The other does not + assert not publish_task1.done() + + # Complete the other one + mock_paho.trigger_on_publish(publish_task1_mid) + await publish_task1 + + @pytest.mark.it("Can handle responses received before or after Paho invocation returns") + @pytest.mark.parametrize("early_ack", early_ack_params) + async def test_early_ack(self, client, mock_paho, early_ack): + mock_paho._early_ack = early_ack + await client.publish(fake_topic, fake_payload) + # If this doesn't hang, the test passes + + @pytest.mark.it("Retains pending publish tracking information only until receiving a response") + async def test_pending(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + + # Start a publish. 
It won't complete + publish_task = asyncio.create_task(client.publish(fake_topic, fake_payload)) + await asyncio.sleep(0.1) + + # Pending publish is tracked + mid = mock_paho._last_mid + assert mid in client._pending_pubs + + # Trigger publish completion + mock_paho.trigger_on_publish(mid) + await publish_task + + # Pending publish is no longer tracked + assert mid not in client._pending_pubs + + @pytest.mark.it( + "Does not establish pending publish tracking information if invoking Paho's publish returns a failed return code" + ) + @pytest.mark.parametrize("failing_rc", publish_failed_rc_params) + async def test_pending_fail_status(self, client, mock_paho, failing_rc): + mock_paho._publish_rc = failing_rc + + with pytest.raises(MQTTError): + await client.publish(fake_topic, fake_payload) + + assert len(client._pending_pubs) == 0 + + @pytest.mark.it( + "Does not establish pending publish tracking information if invoking Paho's publish raises an exception" + ) + async def test_pending_fail_paho_raise(self, client, mock_paho, arbitrary_exception): + mock_paho.publish.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)): + await client.publish(fake_topic, fake_payload) + + assert len(client._pending_subs) == 0 + + @pytest.mark.it( + "Raises CancelledError if cancelled while waiting for the Paho invocation to return" + ) + async def test_cancel_waiting_paho_invocation( + self, + client, + mock_paho, + ): + # Create a fake publish implementation that doesn't return right away + finish_publish = threading.Event() + waiting_on_paho = False + + def fake_publish(*args, **kwargs): + nonlocal waiting_on_paho + waiting_on_paho = True + finish_publish.wait() + waiting_on_paho = False + + mock_paho.publish.side_effect = fake_publish + assert len(client._pending_pubs) == 0 + + # Start a publish task that will hang on Paho invocation + publish_task = asyncio.create_task(client.publish(fake_topic, fake_payload)) + await asyncio.sleep(0.1) + assert not 
publish_task.done() + # Paho invocation has not returned + assert waiting_on_paho + assert len(client._pending_pubs) == 0 + + # Cancel task + publish_task.cancel() + with pytest.raises(asyncio.CancelledError): + await publish_task + + # Allow the fake implementation to finish + finish_publish.set() + + @pytest.mark.it("Raises CancelledError if cancelled while waiting for a response") + async def test_cancel_waiting_response(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + assert len(client._pending_pubs) == 0 + + # Start a publish task and cancel it. + publish_task = asyncio.create_task(client.publish(fake_topic, fake_payload)) + await asyncio.sleep(0.1) + assert not publish_task.done() + # The pub pending means we received a mid from the invocation + # i.e. we are now waiting for a response + assert len(client._pending_pubs) == 1 + + publish_task.cancel() + with pytest.raises(asyncio.CancelledError): + await publish_task + + # NOTE: There's no publish tracking information if cancelled while waiting for invocation + # as we don't have a mid yet. + @pytest.mark.it("Clears pending publish tracking information if cancelled") + async def test_pending_cancelled(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + + # Start a publish. It won't complete + publish_task = asyncio.create_task(client.publish(fake_topic, fake_payload)) + await asyncio.sleep(0.1) + + # Pending publish is tracked + mid = mock_paho._last_mid + assert mid in client._pending_pubs + + # Cancel + publish_task.cancel() + with pytest.raises(asyncio.CancelledError): + await publish_task + + # Pending publish is no longer tracked + assert mid not in client._pending_pubs + + @pytest.mark.it( + "Can handle receiving a response for a publish that was cancelled after it was in-flight" + ) + async def test_ack_after_cancel(self, client, mock_paho): + # Require manual completion + mock_paho._manual_mode = True + + # Start a publish. 
It won't complete + publish_task = asyncio.create_task(client.publish(fake_topic, fake_payload)) + await asyncio.sleep(0.1) + + # Pending publish is tracked + mid = mock_paho._last_mid + assert mid in client._pending_pubs + + # Cancel + publish_task.cancel() + with pytest.raises(asyncio.CancelledError): + await publish_task + + # Pending publish is no longer tracked + assert mid not in client._pending_pubs + + # Trigger publish response after cancellation + mock_paho.trigger_on_publish(mid) + await asyncio.sleep(0.1) + + # No failure, no problem + + +# NOTE: Because so much of the logic of message receives is internal to Paho, to test more detail +# would really just be testing mocks. So we're just going to test the handlers/callbacks provided +# and assume the logic regarding when to use them is correct. As a result, the descriptions of +# these tests somewhat overstate the content of the test, because to truly test what would be +# described with a mocked Paho, would just be testing mocks and side effects. 
+@pytest.mark.describe("MQTTClient - OCCURRENCE: Message Received") +class TestMessageReceived: + @pytest.mark.it( + "Puts the received message in the default message queue if no matching topic filter is defined" + ) + async def test_no_filter(self, client): + assert client._incoming_messages.empty() + + message = mqtt.MQTTMessage(mid=1) + client._mqtt_client.on_message(client, None, message) + await asyncio.sleep(0.1) + + assert not client._incoming_messages.empty() + assert client._incoming_messages.qsize() == 1 + item = await client._incoming_messages.get() + assert item is message + + @pytest.mark.it( + "Puts the received message in a filtered queue if a matching topic filter is defined" + ) + async def test_filter(self, client, mock_paho): + topic1 = fake_topic + topic2 = "even/faker/topic" + + # Get callbacks and queues for filters + client.add_incoming_message_filter(topic1) + topic1_incoming_messages = client._incoming_filtered_messages[topic1] + assert mock_paho.message_callback_add.call_count == 1 + topic1_callback = mock_paho.message_callback_add.call_args[0][1] + + client.add_incoming_message_filter(topic2) + topic2_incoming_messages = client._incoming_filtered_messages[topic2] + assert mock_paho.message_callback_add.call_count == 2 + topic2_callback = mock_paho.message_callback_add.call_args[0][1] + + assert topic1_incoming_messages.empty() + assert topic2_incoming_messages.empty() + assert client._incoming_messages.empty() + + # Receive Messages + message1 = mqtt.MQTTMessage(mid=1) + topic1_callback(client, None, message1) + message2 = mqtt.MQTTMessage(mid=2) + topic2_callback(client, None, message2) + await asyncio.sleep(0.1) + + # Messages were put in correct queue + assert client._incoming_messages.empty() + + assert not topic1_incoming_messages.empty() + assert topic1_incoming_messages.qsize() == 1 + item1 = await topic1_incoming_messages.get() + assert item1 is message1 + + assert not topic2_incoming_messages.empty() + assert 
topic2_incoming_messages.qsize() == 1 + item2 = await topic2_incoming_messages.get() + assert item2 is message2 diff --git a/tests/unit/test_mqtt_topic_iothub.py b/tests/unit/test_mqtt_topic_iothub.py new file mode 100644 index 000000000..4c018ef32 --- /dev/null +++ b/tests/unit/test_mqtt_topic_iothub.py @@ -0,0 +1,812 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import pytest +import logging +from azure.iot.device import mqtt_topic_iothub + +logging.basicConfig(level=logging.DEBUG) + +# NOTE: All tests (that require it) are parametrized with multiple values for URL encoding. +# This is to show that the URL encoding is done correctly - not all URL encoding encodes +# the same way. +# +# For URL encoding, we must always test the ' ' and '/' characters specifically, in addition +# to a generic URL encoding value (e.g. $, #, etc.) +# +# For URL decoding, we must always test the '+' character specifically, in addition to +# a generic URL encoded value (e.g. %24, %23, etc.) +# +# Please also always test that provided values are converted to strings in order to ensure +# that they can be URL encoded without error. +# +# PLEASE DO THESE TESTS FOR EVEN CASES WHERE THOSE CHARACTERS SHOULD NOT OCCUR, FOR SAFETY. 
+ + +@pytest.mark.describe(".get_c2d_topic_for_subscribe()") +class TestGetC2DTopicForSubscribe: + @pytest.mark.it("Returns the topic for subscribing to C2D messages from IoTHub") + def test_returns_topic(self): + device_id = "my_device" + expected_topic = "devices/my_device/messages/devicebound/#" + topic = mqtt_topic_iothub.get_c2d_topic_for_subscribe(device_id) + assert topic == expected_topic + + # NOTE: It SHOULD do URL encoding, but Hub doesn't currently support URL decoding, so we have + # to follow that and not do URL encoding for safety. As a result, some of the values used in + # this test would actually be invalid in production due to character restrictions on the Hub + # that exist to prevent Hub from breaking due to a lack of URL decoding. + # If Hub does begin to support robust URL encoding for safety, this test can easily be switched + # to show that URL encoding DOES work. + @pytest.mark.it("Does NOT URL encode the device_id when generating the topic") + @pytest.mark.parametrize( + "device_id, expected_topic", + [ + pytest.param( + "my$device", "devices/my$device/messages/devicebound/#", id="id contains '$'" + ), + pytest.param( + "my device", "devices/my device/messages/devicebound/#", id="id contains ' '" + ), + pytest.param( + "my/device", "devices/my/device/messages/devicebound/#", id="id contains '/'" + ), + ], + ) + def test_url_encoding(self, device_id, expected_topic): + topic = mqtt_topic_iothub.get_c2d_topic_for_subscribe(device_id) + assert topic == expected_topic + + @pytest.mark.it("Converts the device_id to string when generating the topic") + def test_str_conversion(self): + device_id = 2000 + expected_topic = "devices/2000/messages/devicebound/#" + topic = mqtt_topic_iothub.get_c2d_topic_for_subscribe(device_id) + assert topic == expected_topic + + +@pytest.mark.describe(".get_input_topic_for_subscribe()") +class TestGetInputTopicForSubscribe: + @pytest.mark.it("Returns the topic for subscribing to Input messages from IoTHub") + def 
test_returns_topic(self): + device_id = "my_device" + module_id = "my_module" + expected_topic = "devices/my_device/modules/my_module/inputs/#" + topic = mqtt_topic_iothub.get_input_topic_for_subscribe(device_id, module_id) + assert topic == expected_topic + + # NOTE: It SHOULD do URL encoding, but Hub doesn't currently support URL decoding, so we have + # to follow that and not do URL encoding for safety. As a result, some of the values used in + # this test would actually be invalid in production due to character restrictions on the Hub + # that exist to prevent Hub from breaking due to a lack of URL decoding. + # If Hub does begin to support robust URL encoding for safety, this test can easily be switched + # to show that URL encoding DOES work. + @pytest.mark.it("URL encodes the device_id and module_id when generating the topic") + @pytest.mark.parametrize( + "device_id, module_id, expected_topic", + [ + pytest.param( + "my$device", + "my$module", + "devices/my$device/modules/my$module/inputs/#", + id="ids contain '$'", + ), + pytest.param( + "my device", + "my module", + "devices/my device/modules/my module/inputs/#", + id="ids contain ' '", + ), + pytest.param( + "my/device", + "my/module", + "devices/my/device/modules/my/module/inputs/#", + id="ids contain '/'", + ), + ], + ) + def test_url_encoding(self, device_id, module_id, expected_topic): + topic = mqtt_topic_iothub.get_input_topic_for_subscribe(device_id, module_id) + assert topic == expected_topic + + @pytest.mark.it("Converts the device_id and module_id to string when generating the topic") + def test_str_conversion(self): + device_id = 2000 + module_id = 4000 + expected_topic = "devices/2000/modules/4000/inputs/#" + topic = mqtt_topic_iothub.get_input_topic_for_subscribe(device_id, module_id) + assert topic == expected_topic + + +@pytest.mark.describe(".get_direct_method_request_topic_for_subscribe()") +class TestGetMethodTopicForSubscribe: + @pytest.mark.it("Returns the topic for subscribing to 
methods from IoTHub") + def test_returns_topic(self): + topic = mqtt_topic_iothub.get_direct_method_request_topic_for_subscribe() + assert topic == "$iothub/methods/POST/#" + + +@pytest.mark.describe("get_twin_response_topic_for_subscribe()") +class TestGetTwinResponseTopicForSubscribe: + @pytest.mark.it("Returns the topic for subscribing to twin response from IoTHub") + def test_returns_topic(self): + topic = mqtt_topic_iothub.get_twin_response_topic_for_subscribe() + assert topic == "$iothub/twin/res/#" + + +@pytest.mark.describe("get_twin_patch_topic_for_subscribe()") +class TestGetTwinPatchTopicForSubscribe: + @pytest.mark.it("Returns the topic for subscribing to twin patches from IoTHub") + def test_returns_topic(self): + topic = mqtt_topic_iothub.get_twin_patch_topic_for_subscribe() + assert topic == "$iothub/twin/PATCH/properties/desired/#" + + +@pytest.mark.describe(".get_telemetry_topic_for_publish()") +class TestGetTelemetryTopicForPublish: + @pytest.mark.it("Returns the topic for sending telemetry to IoTHub") + @pytest.mark.parametrize( + "device_id, module_id, expected_topic", + [ + pytest.param("my_device", None, "devices/my_device/messages/events/", id="Device"), + pytest.param( + "my_device", + "my_module", + "devices/my_device/modules/my_module/messages/events/", + id="Module", + ), + ], + ) + def test_returns_topic(self, device_id, module_id, expected_topic): + topic = mqtt_topic_iothub.get_telemetry_topic_for_publish(device_id, module_id) + assert topic == expected_topic + + # NOTE: It SHOULD do URL encoding, but Hub doesn't currently support URL decoding, so we have + # to follow that and not do URL encoding for safety. As a result, some of the values used in + # this test would actually be invalid in production due to character restrictions on the Hub + # that exist to prevent Hub from breaking due to a lack of URL decoding. 
+ # If Hub does begin to support robust URL encoding for safety, this test can easily be switched + # to show that URL encoding DOES work. + @pytest.mark.it("URL encodes the device_id and module_id when generating the topic") + @pytest.mark.parametrize( + "device_id, module_id, expected_topic", + [ + pytest.param( + "my$device", + None, + "devices/my$device/messages/events/", + id="Device, id contains '$'", + ), + pytest.param( + "my device", + None, + "devices/my device/messages/events/", + id="Device, id contains ' '", + ), + pytest.param( + "my/device", + None, + "devices/my/device/messages/events/", + id="Device, id contains '/'", + ), + pytest.param( + "my$device", + "my$module", + "devices/my$device/modules/my$module/messages/events/", + id="Module, ids contain '$'", + ), + pytest.param( + "my device", + "my module", + "devices/my device/modules/my module/messages/events/", + id="Module, ids contain ' '", + ), + pytest.param( + "my/device", + "my/module", + "devices/my/device/modules/my/module/messages/events/", + id="Module, ids contain '/'", + ), + ], + ) + def test_url_encoding(self, device_id, module_id, expected_topic): + topic = mqtt_topic_iothub.get_telemetry_topic_for_publish(device_id, module_id) + assert topic == expected_topic + + @pytest.mark.it("Converts the device_id and module_id to string when generating the topic") + @pytest.mark.parametrize( + "device_id, module_id, expected_topic", + [ + pytest.param(2000, None, "devices/2000/messages/events/", id="Device"), + pytest.param(2000, 4000, "devices/2000/modules/4000/messages/events/", id="Module"), + ], + ) + def test_str_conversion(self, device_id, module_id, expected_topic): + topic = mqtt_topic_iothub.get_telemetry_topic_for_publish(device_id, module_id) + assert topic == expected_topic + + +@pytest.mark.describe(".get_direct_method_response_topic_for_publish()") +class TestGetMethodTopicForPublish: + @pytest.mark.it("Returns the topic for sending a direct method response to IoTHub") + 
@pytest.mark.parametrize( + "request_id, status, expected_topic", + [ + pytest.param("1", "200", "$iothub/methods/res/200/?$rid=1", id="Successful result"), + pytest.param( + "475764", "500", "$iothub/methods/res/500/?$rid=475764", id="Failure result" + ), + ], + ) + def test_returns_topic(self, request_id, status, expected_topic): + topic = mqtt_topic_iothub.get_direct_method_response_topic_for_publish(request_id, status) + assert topic == expected_topic + + @pytest.mark.it("URL encodes provided values when generating the topic") + @pytest.mark.parametrize( + "request_id, status, expected_topic", + [ + pytest.param( + "invalid#request?id", + "invalid$status", + "$iothub/methods/res/invalid%24status/?$rid=invalid%23request%3Fid", + id="Standard URL Encoding", + ), + pytest.param( + "invalid request id", + "invalid status", + "$iothub/methods/res/invalid%20status/?$rid=invalid%20request%20id", + id="URL Encoding of ' ' character", + ), + pytest.param( + "invalid/request/id", + "invalid/status", + "$iothub/methods/res/invalid%2Fstatus/?$rid=invalid%2Frequest%2Fid", + id="URL Encoding of '/' character", + ), + ], + ) + def test_url_encoding(self, request_id, status, expected_topic): + topic = mqtt_topic_iothub.get_direct_method_response_topic_for_publish(request_id, status) + assert topic == expected_topic + + @pytest.mark.it("Converts the provided values to strings when generating the topic") + def test_str_conversion(self): + request_id = 1 + status = 200 + expected_topic = "$iothub/methods/res/200/?$rid=1" + topic = mqtt_topic_iothub.get_direct_method_response_topic_for_publish(request_id, status) + assert topic == expected_topic + + +@pytest.mark.describe(".get_twin_request_topic_for_publish()") +class TestGetTwinRequestTopicForPublish: + @pytest.mark.it("Returns topic for sending a get twin request to IoTHub") + def test_returns_topic(self): + request_id = "3226c2f7-3d30-425c-b83b-0c34335f8220" + expected_topic = 
"$iothub/twin/GET/?$rid=3226c2f7-3d30-425c-b83b-0c34335f8220" + topic = mqtt_topic_iothub.get_twin_request_topic_for_publish(request_id) + assert topic == expected_topic + + @pytest.mark.it("URL encodes 'request_id' parameter when generating the topic") + @pytest.mark.parametrize( + "request_id, expected_topic", + [ + pytest.param( + "invalid$request?id", + "$iothub/twin/GET/?$rid=invalid%24request%3Fid", + id="Standard URL Encoding", + ), + pytest.param( + "invalid request id", + "$iothub/twin/GET/?$rid=invalid%20request%20id", + id="URL Encoding of ' ' character", + ), + pytest.param( + "invalid/request/id", + "$iothub/twin/GET/?$rid=invalid%2Frequest%2Fid", + id="URL Encoding of '/' character", + ), + ], + ) + def test_url_encoding(self, request_id, expected_topic): + topic = mqtt_topic_iothub.get_twin_request_topic_for_publish(request_id) + assert topic == expected_topic + + @pytest.mark.it("Converts 'request_id' parameter to string when generating the topic") + def test_str_conversion(self): + request_id = 4000 + expected_topic = "$iothub/twin/GET/?$rid=4000" + topic = mqtt_topic_iothub.get_twin_request_topic_for_publish(request_id) + assert topic == expected_topic + + +@pytest.mark.describe(".get_twin_patch_topic_for_publish()") +class TestGetTwinPatchTopicForPublish: + @pytest.mark.it("Returns topic for sending a twin patch to IoTHub") + def test_returns_topic(self): + request_id = "5002b415-af16-47e9-b89c-8680e01b502f" + expected_topic = ( + "$iothub/twin/PATCH/properties/reported/?$rid=5002b415-af16-47e9-b89c-8680e01b502f" + ) + topic = mqtt_topic_iothub.get_twin_patch_topic_for_publish(request_id) + assert topic == expected_topic + + @pytest.mark.it("URL encodes 'request_id' parameter when generating the topic") + @pytest.mark.parametrize( + "request_id, expected_topic", + [ + pytest.param( + "invalid$request?id", + "$iothub/twin/PATCH/properties/reported/?$rid=invalid%24request%3Fid", + id="Standard URL Encoding", + ), + pytest.param( + "invalid request 
id", + "$iothub/twin/PATCH/properties/reported/?$rid=invalid%20request%20id", + id="URL Encoding of ' ' character", + ), + pytest.param( + "invalid/request/id", + "$iothub/twin/PATCH/properties/reported/?$rid=invalid%2Frequest%2Fid", + id="URL Encoding of '/' character", + ), + ], + ) + def test_url_encoding(self, request_id, expected_topic): + topic = mqtt_topic_iothub.get_twin_patch_topic_for_publish(request_id) + assert topic == expected_topic + + @pytest.mark.it("Converts 'request_id' parameter to string when generating the topic") + def test_str_conversion(self): + request_id = 4000 + expected_topic = "$iothub/twin/PATCH/properties/reported/?$rid=4000" + topic = mqtt_topic_iothub.get_twin_patch_topic_for_publish(request_id) + assert topic == expected_topic + + +@pytest.mark.describe(".insert_message_properties_in_topic()") +class TestEncodeMessagePropertiesInTopic: + @pytest.fixture(params=["C2D Message", "Input Message"]) + def message_topic(self, request): + if request.param == "C2D Message": + return "devices/fake_device/messages/events/" + else: + return "devices/fake_device/modules/fake_module/messages/events/" + + @pytest.mark.it( + "Returns a new version of the given topic string that contains the provided properties as key/value pairs" + ) + @pytest.mark.parametrize( + "system_properties, custom_properties, expected_encoding", + [ + pytest.param({}, {}, "", id="No Properties"), + pytest.param( + {"sys_prop1": "value1", "sys_prop2": "value2"}, + {}, + "sys_prop1=value1&sys_prop2=value2", + id="System Properties Only", + ), + pytest.param( + {}, + {"cust_prop1": "value3", "cust_prop2": "value4"}, + "cust_prop1=value3&cust_prop2=value4", + id="Custom Properties Only", + ), + pytest.param( + {"sys_prop1": "value1", "sys_prop2": "value2"}, + {"cust_prop1": "value3", "cust_prop2": "value4"}, + "sys_prop1=value1&sys_prop2=value2&cust_prop1=value3&cust_prop2=value4", + id="System Properties and Custom Properties", + ), + ], + ) + def test_adds_properties( + 
self, message_topic, system_properties, custom_properties, expected_encoding + ): + expected_topic = message_topic + expected_encoding + encoded_topic = mqtt_topic_iothub.insert_message_properties_in_topic( + message_topic, system_properties, custom_properties + ) + assert encoded_topic == expected_topic + + @pytest.mark.it( + "URL encodes keys and values in the provided properties when adding them to the topic string" + ) + @pytest.mark.parametrize( + "system_properties, custom_properties, expected_encoding", + [ + pytest.param( + {"$.mid": "message#id", "$.ce": "utf-#"}, + {}, + "%24.mid=message%23id&%24.ce=utf-%23", + id="System Properties Only (Standard URL Encoding)", + ), + pytest.param( + {}, + {"cu$tom1": "value#3", "cu$tom2": "value#4"}, + "cu%24tom1=value%233&cu%24tom2=value%234", + id="Custom Properties Only (Standard URL Encoding)", + ), + pytest.param( + {"$.mid": "message#id", "$.ce": "utf-#"}, + {"cu$tom1": "value#3", "cu$tom2": "value#4"}, + "%24.mid=message%23id&%24.ce=utf-%23&cu%24tom1=value%233&cu%24tom2=value%234", + id="System Properties and Custom Properties (Standard URL Encoding)", + ), + pytest.param( + {"m id": "message id", "c e": "utf 8"}, + {}, + "m%20id=message%20id&c%20e=utf%208", + id="System Properties Only (URL Encoding of ' ' Character)", + ), + pytest.param( + {}, + {"custom 1": "value 1", "custom 2": "value 2"}, + "custom%201=value%201&custom%202=value%202", + id="Custom Properties Only (URL Encoding of ' ' Character)", + ), + pytest.param( + {"m id": "message id", "c e": "utf 8"}, + {"custom 1": "value 1", "custom 2": "value 2"}, + "m%20id=message%20id&c%20e=utf%208&custom%201=value%201&custom%202=value%202", + id="System Properties and Custom Properties (URL Encoding of ' ' Character)", + ), + pytest.param( + {"m/id": "message/id", "c/e": "utf/8"}, + {}, + "m%2Fid=message%2Fid&c%2Fe=utf%2F8", + id="System Properties Only (URL Encoding of '/' Character)", + ), + pytest.param( + {}, + {"custom/1": "value/1", "custom/2": 
"value/2"}, + "custom%2F1=value%2F1&custom%2F2=value%2F2", + id="Custom Properties Only (URL Encoding of '/' Character)", + ), + pytest.param( + {"m/id": "message/id", "c/e": "utf/8"}, + {"custom/1": "value/1", "custom/2": "value/2"}, + "m%2Fid=message%2Fid&c%2Fe=utf%2F8&custom%2F1=value%2F1&custom%2F2=value%2F2", + id="System Properties and Custom Properties (URL Encoding of '/' Character)", + ), + ], + ) + def test_url_encoding( + self, message_topic, system_properties, custom_properties, expected_encoding + ): + expected_topic = message_topic + expected_encoding + encoded_topic = mqtt_topic_iothub.insert_message_properties_in_topic( + message_topic, system_properties, custom_properties + ) + assert encoded_topic == expected_topic + + +@pytest.mark.describe(".extract_properties_from_message_topic()") +class TestExtractPropertiesFromMessageTopic: + @pytest.fixture(params=["C2D Message", "Input Message"]) + def message_topic_base(self, request): + if request.param == "C2D Message": + return "devices/fake_device/messages/devicebound/" + else: + return "devices/fake_device/modules/fake_module/inputs/fake_input/" + + @pytest.mark.it( + "Returns a dictionary mapping of all key/value pairs contained within the given topic string" + ) + @pytest.mark.parametrize( + "property_string, expected_property_dict", + [ + pytest.param("", {}, id="No properties"), + pytest.param( + "key1=value1&key2=value2&key3=value3", + {"key1": "value1", "key2": "value2", "key3": "value3"}, + id="Some Properties", + ), + ], + ) + def test_returns_map(self, message_topic_base, property_string, expected_property_dict): + topic = message_topic_base + property_string + properties = mqtt_topic_iothub.extract_properties_from_message_topic(topic) + assert properties == expected_property_dict + + @pytest.mark.it("URL decodes the key/value pairs extracted from the topic") + @pytest.mark.parametrize( + "property_string, expected_property_dict", + [ + pytest.param( + 
"%24.key1=value%231&%24.key2=value%232", + {"$.key1": "value#1", "$.key2": "value#2"}, + id="Standard URL Decoding", + ), + pytest.param( + "key%201=value%201&key%202=value%202", + {"key 1": "value 1", "key 2": "value 2"}, + id="URL Encoding of ' ' Character", + ), + pytest.param( + "key%2F1=value%2F1&key%2F2=value%2F2", + {"key/1": "value/1", "key/2": "value/2"}, + id="URL Encoding of '/' Character", + ), + ], + ) + def test_url_decoding(self, message_topic_base, property_string, expected_property_dict): + topic = message_topic_base + property_string + properties = mqtt_topic_iothub.extract_properties_from_message_topic(topic) + assert properties == expected_property_dict + + @pytest.mark.it("Supports empty string in properties") + @pytest.mark.parametrize( + "property_string, expected_property_dict", + [ + pytest.param("=value1", {"": "value1"}, id="Empty String Key"), + pytest.param("key1=", {"key1": ""}, id="Empty String Value"), + ], + ) + def test_empty_string(self, message_topic_base, property_string, expected_property_dict): + topic = message_topic_base + property_string + properties = mqtt_topic_iothub.extract_properties_from_message_topic(topic) + assert properties == expected_property_dict + + @pytest.mark.it( + "Maps the extracted key to value of empty string if there is a key with no corresponding value present" + ) + def test_key_only(self, message_topic_base): + property_string = "key1&key2&key3=value" + expected_property_dict = {"key1": "", "key2": "", "key3": "value"} + topic = message_topic_base + property_string + properties = mqtt_topic_iothub.extract_properties_from_message_topic(topic) + assert properties == expected_property_dict + + @pytest.mark.it( + "Raises a ValueError if the provided topic is not a C2D topic or an Input Message topic" + ) + @pytest.mark.parametrize( + "topic", + [ + pytest.param("not a topic", id="Not a topic"), + pytest.param( + "$iothub/twin/res/200/?$rid=d9d7ce4d-3be9-498b-abde-913b81b880e5", + id="Topic of wrong 
type", + ), + pytest.param( + "devices/fake_device/messages/devicebnd/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmessages%2Fdevicebound", + id="Malformed C2D topic", + ), + pytest.param( + "devices/fake_device/modules/fake_module/inutps/fake_input/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2Ffake_device%2Fmodules%2Ffake_module%2Finputs%2Ffake_input", + id="Malformed input message topic", + ), + ], + ) + def test_bad_topic(self, topic): + with pytest.raises(ValueError): + mqtt_topic_iothub.extract_properties_from_message_topic(topic) + + +@pytest.mark.describe(".extract_name_from_direct_method_request_topic()") +class TestExtractNameFromMethodRequestTopic: + @pytest.mark.it("Returns the method name from a method topic") + def test_valid_direct_method_topic(self): + topic = "$iothub/methods/POST/fake_method/?$rid=1" + expected_method_name = "fake_method" + + assert ( + mqtt_topic_iothub.extract_name_from_direct_method_request_topic(topic) + == expected_method_name + ) + + @pytest.mark.it("URL decodes the returned method name") + @pytest.mark.parametrize( + "topic, expected_method_name", + [ + pytest.param( + "$iothub/methods/POST/fake%24method/?$rid=1", + "fake$method", + id="Standard URL Decoding", + ), + pytest.param( + "$iothub/methods/POST/fake+method/?$rid=1", + "fake+method", + id="Does NOT decode '+' character", + ), + ], + ) + def test_url_decodes_value(self, topic, expected_method_name): + assert ( + mqtt_topic_iothub.extract_name_from_direct_method_request_topic(topic) + == expected_method_name + ) + + @pytest.mark.it("Raises a ValueError if the provided topic is not a method topic") + @pytest.mark.parametrize( + "topic", + [ + pytest.param("not a topic", id="Not a topic"), + pytest.param( + "devices/fake_device/modules/fake_module/inputs/fake_input", + id="Topic of wrong type", + ), + pytest.param("$iothub/methdos/POST/fake_method/?$rid=1", id="Malformed topic"), + ], + ) + def 
test_invalid_direct_method_topic(self, topic): + with pytest.raises(ValueError): + mqtt_topic_iothub.extract_name_from_direct_method_request_topic(topic) + + +@pytest.mark.describe(".extract_request_id_from_direct_method_request_topic()") +class TestExtractRequestIdFromMethodRequestTopic: + @pytest.mark.it("Returns the request id from a method topic") + def test_valid_direct_method_topic(self): + topic = "$iothub/methods/POST/fake_method/?$rid=1" + expected_request_id = "1" + + assert ( + mqtt_topic_iothub.extract_request_id_from_direct_method_request_topic(topic) + == expected_request_id + ) + + @pytest.mark.it("URL decodes the returned value") + @pytest.mark.parametrize( + "topic, expected_request_id", + [ + pytest.param( + "$iothub/methods/POST/fake_method/?$rid=fake%24request%2Fid", + "fake$request/id", + id="Standard URL Decoding", + ), + pytest.param( + "$iothub/methods/POST/fake_method/?$rid=fake+request+id", + "fake+request+id", + id="Does NOT decode '+' character", + ), + ], + ) + def test_url_decodes_value(self, topic, expected_request_id): + assert ( + mqtt_topic_iothub.extract_request_id_from_direct_method_request_topic(topic) + == expected_request_id + ) + + @pytest.mark.it("Raises a ValueError if the provided topic is not a method topic") + @pytest.mark.parametrize( + "topic", + [ + pytest.param("not a topic", id="Not a topic"), + pytest.param( + "devices/fake_device/modules/fake_module/inputs/fake_input", + id="Topic of wrong type", + ), + pytest.param("$iothub/methdos/POST/fake_method/?$rid=1", id="Malformed topic"), + ], + ) + def test_invalid_direct_method_topic(self, topic): + with pytest.raises(ValueError): + mqtt_topic_iothub.extract_request_id_from_direct_method_request_topic(topic) + + @pytest.mark.it("Raises a ValueError if the provided topic does not contain a request id") + @pytest.mark.parametrize( + "topic", + [ + pytest.param("$iothub/methods/POST/fake_method/?$mid=1", id="No request id key"), + 
pytest.param("$iothub/methods/POST/fake_method/?$rid", id="No request id value"), + pytest.param("$iothub/methods/POST/fake_method/?$rid=", id="Empty request id value"), + ], + ) + def test_no_request_id(self, topic): + with pytest.raises(ValueError): + mqtt_topic_iothub.extract_request_id_from_direct_method_request_topic(topic) + + +@pytest.mark.describe(".extract_status_code_from_twin_response_topic()") +class TestExtractStatusCodeFromTwinResponseTopic: + @pytest.mark.it("Returns the status from a twin response topic") + def test_valid_twin_response_topic(self): + topic = "$iothub/twin/res/200/?rid=1" + expected_status = "200" + + assert ( + mqtt_topic_iothub.extract_status_code_from_twin_response_topic(topic) == expected_status + ) + + @pytest.mark.it("URL decodes the returned value") + @pytest.mark.parametrize( + "topic, expected_status", + [ + pytest.param("$iothub/twin/res/%24%24%24/?rid=1", "$$$", id="Standard URL decoding"), + pytest.param( + "$iothub/twin/res/invalid+status/?rid=1", + "invalid+status", + id="Does NOT decode '+' character", + ), + ], + ) + def test_url_decode(self, topic, expected_status): + assert ( + mqtt_topic_iothub.extract_status_code_from_twin_response_topic(topic) == expected_status + ) + + @pytest.mark.it("Raises a ValueError if the provided topic is not a twin response topic") + @pytest.mark.parametrize( + "topic", + [ + pytest.param("not a topic", id="Not a topic"), + pytest.param( + "devices/fake_device/modules/fake_module/inputs/fake_input", + id="Topic of wrong type", + ), + pytest.param("$iothub/twn/res/200?rid=1", id="Malformed topic"), + ], + ) + def test_invalid_twin_response_topic(self, topic): + with pytest.raises(ValueError): + mqtt_topic_iothub.extract_status_code_from_twin_response_topic(topic) + + +@pytest.mark.describe(".extract_request_id_from_twin_response_topic()") +class TestExtractRequestIdFromTwinResponseTopic: + @pytest.mark.it("Returns the request id from a twin response topic") + def 
test_valid_twin_response_topic(self): + topic = "$iothub/twin/res/200/?$rid=1" + expected_request_id = "1" + + assert ( + mqtt_topic_iothub.extract_request_id_from_twin_response_topic(topic) + == expected_request_id + ) + + @pytest.mark.it("URL decodes the returned value") + @pytest.mark.parametrize( + "topic, expected_request_id", + [ + pytest.param( + "$iothub/twin/res/200/?$rid=fake%24request%2Fid", + "fake$request/id", + id="Standard URL Decoding", + ), + pytest.param( + "$iothub/twin/res/200/?$rid=fake+request+id", + "fake+request+id", + id="Does NOT decode '+' character", + ), + ], + ) + def test_url_decodes_value(self, topic, expected_request_id): + assert ( + mqtt_topic_iothub.extract_request_id_from_twin_response_topic(topic) + == expected_request_id + ) + + @pytest.mark.it("Raises a ValueError if the provided topic is not a twin response topic") + @pytest.mark.parametrize( + "topic", + [ + pytest.param("not a topic", id="Not a topic"), + pytest.param( + "devices/fake_device/modules/fake_module/inputs/fake_input", + id="Topic of wrong type", + ), + pytest.param("$iothub/twn/res/200?$rid=1", id="Malformed topic"), + ], + ) + def test_invalid_twin_response_topic(self, topic): + with pytest.raises(ValueError): + mqtt_topic_iothub.extract_request_id_from_twin_response_topic(topic) + + @pytest.mark.it("Raises a ValueError if the provided topic does not contain a request id") + @pytest.mark.parametrize( + "topic", + [ + pytest.param("$iothub/twin/res/200/?$mid=1", id="No request id key"), + pytest.param("$iothub/twin/res/200/?$rid", id="No request id value"), + pytest.param("$iothub/twin/res/200/?$rid=", id="Empty request id value"), + ], + ) + def test_no_request_id(self, topic): + with pytest.raises(ValueError): + mqtt_topic_iothub.extract_request_id_from_twin_response_topic(topic) diff --git a/tests/unit/provisioning/pipeline/test_mqtt_topic_provisioning.py b/tests/unit/test_mqtt_topic_provisioning.py similarity index 60% rename from 
tests/unit/provisioning/pipeline/test_mqtt_topic_provisioning.py rename to tests/unit/test_mqtt_topic_provisioning.py index beba95e7a..7a539a52e 100644 --- a/tests/unit/provisioning/pipeline/test_mqtt_topic_provisioning.py +++ b/tests/unit/test_mqtt_topic_provisioning.py @@ -6,7 +6,7 @@ import pytest import logging -from azure.iot.device.provisioning.pipeline import mqtt_topic_provisioning +from azure.iot.device import mqtt_topic_provisioning logging.basicConfig(level=logging.DEBUG) @@ -26,11 +26,11 @@ # PLEASE DO THESE TESTS FOR EVEN CASES WHERE THOSE CHARACTERS SHOULD NOT OCCUR, FOR SAFETY. -@pytest.mark.describe(".get_register_topic_for_subscribe()") -class TestGetRegisterTopicForSubscribe(object): - @pytest.mark.it("Returns the topic for subscribing to registration responses from DPS") +@pytest.mark.describe(".get_response_topic_for_subscribe()") +class TestGetResponseTopicForSubscribe(object): + @pytest.mark.it("Returns the topic for subscribing to responses from DPS") def test_returns_topic(self): - topic = mqtt_topic_provisioning.get_register_topic_for_subscribe() + topic = mqtt_topic_provisioning.get_response_topic_for_subscribe() assert topic == "$dps/registrations/res/#" @@ -78,13 +78,13 @@ def test_string_conversion(self): assert topic == expected_topic -class TestGetQueryTopicForPublish(object): - @pytest.mark.it("Returns the topic for publishing query requests to DPS") +class TestGetStatusQueryTopicForPublish(object): + @pytest.mark.it("Returns the topic for publishing status query requests to DPS") def test_returns_topic(self): request_id = "3226c2f7-3d30-425c-b83b-0c34335f8220" operation_id = "4.79f33f69d8eb3870.da2d9251-3097-43e9-b81c-782718485ce7" expected_topic = "$dps/registrations/GET/iotdps-get-operationstatus/?$rid=3226c2f7-3d30-425c-b83b-0c34335f8220&operationId=4.79f33f69d8eb3870.da2d9251-3097-43e9-b81c-782718485ce7" - topic = mqtt_topic_provisioning.get_query_topic_for_publish(request_id, operation_id) + topic = 
mqtt_topic_provisioning.get_status_query_topic_for_publish(request_id, operation_id) assert topic == expected_topic @pytest.mark.it("URL encodes the request id and operation id when generating the topic") @@ -112,7 +112,7 @@ def test_returns_topic(self): ], ) def test_url_encoding(self, request_id, operation_id, expected_topic): - topic = mqtt_topic_provisioning.get_query_topic_for_publish(request_id, operation_id) + topic = mqtt_topic_provisioning.get_status_query_topic_for_publish(request_id, operation_id) assert topic == expected_topic @pytest.mark.it("Converts the request id and operation id to string when generating the topic") @@ -122,125 +122,116 @@ def test_string_conversion(self): expected_topic = ( "$dps/registrations/GET/iotdps-get-operationstatus/?$rid=1234&operationId=4567" ) - topic = mqtt_topic_provisioning.get_query_topic_for_publish(request_id, operation_id) + topic = mqtt_topic_provisioning.get_status_query_topic_for_publish(request_id, operation_id) assert topic == expected_topic -@pytest.mark.describe(".is_dps_response_topic()") -class TestIsDpsResponseTopic(object): - @pytest.mark.it("Returns True if the topic is a DPS response topic") - @pytest.mark.parametrize( - "topic", - [ - pytest.param( - "$dps/registrations/res/200/?$rid=3226c2f7-3d30-425c-b83b-0c34335f8220", - id="Successful (200) response", - ), - pytest.param( - "$dps/registrations/res/202/?$rid=3226c2f7-3d30-425c-b83b-0c34335f8220&retry-after=3", - id="Retry-after (202) response", - ), - pytest.param( - "$dps/registrations/res/401/?$rid=3226c2f7-3d30-425c-b83b-0c34335f8220", - id="Unauthorized (401) response", - ), - ], - ) - def test_is_dps_response_topic(self, topic): - assert mqtt_topic_provisioning.is_dps_response_topic(topic) +@pytest.mark.describe(".extract_properties_from_response_topic()") +class TestExtractPropertiesFromResponseTopic(object): + @pytest.fixture + def topic_base(self): + return "$dps/registrations/res/200/?" 
- @pytest.mark.it("Returns False if the topic is not a DPS response topic") + @pytest.mark.it( + "Returns a dictionary mapping of all key/value pairs contained within the given topic string" + ) @pytest.mark.parametrize( - "topic", + "property_string, expected_dict", [ - pytest.param("not a topic", id="Not a topic"), + pytest.param("", {}, id="No properties"), pytest.param( - "$dps/registrations/PUT/iotdps-register/?$rid=3226c2f7-3d30-425c-b83b-0c34335f8220", - id="Topic of wrong type", - ), - pytest.param( - "$dps/resigtrations/res/200/?$rid=3226c2f7-3d30-425c-b83b-0c34335f8220", - id="Malformed topic", + "key1=value1&key2=value2&key3=value3", + {"key1": "value1", "key2": "value2", "key3": "value3"}, + id="Some properties", ), ], ) - def test_is_not_dps_response_topic(self, topic): - assert not mqtt_topic_provisioning.is_dps_response_topic(topic) - + def test_returns_properties(self, topic_base, property_string, expected_dict): + topic = topic_base + property_string + assert ( + mqtt_topic_provisioning.extract_properties_from_response_topic(topic) == expected_dict + ) -@pytest.mark.describe(".extract_properties_from_dps_response_topic()") -class TestExtractPropertiesFromDpsResponseTopic(object): - @pytest.mark.it("Returns the properties from a valid DPS response topic as a dictionary") + @pytest.mark.it("URL decodes the key/value pairs extracted from the response topic") @pytest.mark.parametrize( - "topic, expected_dict", + "property_string, expected_dict", [ pytest.param( - "$dps/registrations/res/200/?$rid=3226c2f7-3d30-425c-b83b-0c34335f8220", - {"rid": "3226c2f7-3d30-425c-b83b-0c34335f8220"}, - id="Successful (200) response", + "%24.key1=value%231&%24.key2=value%232", + {"$.key1": "value#1", "$.key2": "value#2"}, + id="Standard URL Decoding", ), pytest.param( - "$dps/registrations/res/202/?$rid=3226c2f7-3d30-425c-b83b-0c34335f8220&retry-after=3", - {"rid": "3226c2f7-3d30-425c-b83b-0c34335f8220", "retry-after": "3"}, - id="Retry-after (202) response", + 
"key%201=value%201&key%202=value%202", + {"key 1": "value 1", "key 2": "value 2"}, + id="URL Encoding of ' ' Character", ), pytest.param( - "$dps/registrations/res/401/?$rid=3226c2f7-3d30-425c-b83b-0c34335f8220", - {"rid": "3226c2f7-3d30-425c-b83b-0c34335f8220"}, - id="Unauthorized (401) response", + "key%2F1=value%2F1&key%2F2=value%2F2", + {"key/1": "value/1", "key/2": "value/2"}, + id="URL Encoding of '/' Character", ), pytest.param( - "$dps/registrations/res/200/?$rid=3226c2f7-3d30-425c-b83b-0c34335f8220&foo=value1&bar=value2&buzz=value3", - { - "rid": "3226c2f7-3d30-425c-b83b-0c34335f8220", - "foo": "value1", - "bar": "value2", - "buzz": "value3", - }, - id="Arbitrary number of properties in response", + "key+1=request+id", + {"key+1": "request+id"}, + id="Does NOT decode '+' character", ), ], ) - def test_returns_properties(self, topic, expected_dict): + def test_url_decode_properties(self, topic_base, property_string, expected_dict): + topic = topic_base + property_string assert ( - mqtt_topic_provisioning.extract_properties_from_dps_response_topic(topic) - == expected_dict + mqtt_topic_provisioning.extract_properties_from_response_topic(topic) == expected_dict ) - @pytest.mark.it("URL decodes properties extracted from the DPS response topic") + @pytest.mark.it("Supports empty string in properties") @pytest.mark.parametrize( - "topic, expected_dict", + "property_string, expected_dict", [ - pytest.param( - "$dps/registrations/res/200/?$rid=request%3Fid", - {"rid": "request?id"}, - id="Standard URL decoding", - ), - pytest.param( - "$dps/registrations/res/200/?$rid=request+id", - {"rid": "request+id"}, - id="Ddoes NOT decode '+' character", - ), + pytest.param("=value1", {"": "value1"}, id="Empty String Key"), + pytest.param("key1=", {"key1": ""}, id="Empty String Value"), ], ) - def test_url_decode_properties(self, topic, expected_dict): + def test_empty_string(self, topic_base, property_string, expected_dict): + topic = topic_base + property_string assert ( - 
mqtt_topic_provisioning.extract_properties_from_dps_response_topic(topic) - == expected_dict + mqtt_topic_provisioning.extract_properties_from_response_topic(topic) == expected_dict ) @pytest.mark.it( - "Raises ValueError if there are duplicate property keys in the DPS response topic" + "Maps the extracted key to value of empty string if there is a key with no corresponding value present" ) - def test_duplicate_keys(self): - topic = "$dps/registrations/res/200/?$rid=3226c2f7-3d30-425c-b83b-0c34335f8220&rid=something-else" + def test_key_only(self, topic_base): + property_string = "key1&key2&key3=value" + expected_dict = {"key1": "", "key2": "", "key3": "value"} + topic = topic_base + property_string + assert ( + mqtt_topic_provisioning.extract_properties_from_response_topic(topic) == expected_dict + ) + + @pytest.mark.it("Raises a ValueError if the provided topic is not DPS response topic") + @pytest.mark.parametrize( + "topic", + [ + pytest.param("not a topic", id="Not a topic"), + pytest.param( + "$iothub/twin/res/200/?$rid=d9d7ce4d-3be9-498b-abde-913b81b880e5", + id="Topic of wrong type", + ), + pytest.param( + "$dps/registrtaisons/res/200/?$rid=d9d7ce4d-3be9-498b-abde-913b81b880e5", + id="Malformed response topic", + ), + ], + ) + def test_bad_topic(self, topic): with pytest.raises(ValueError): - mqtt_topic_provisioning.extract_properties_from_dps_response_topic(topic) + mqtt_topic_provisioning.extract_properties_from_response_topic(topic) -@pytest.mark.describe(".extract_status_code_from_dps_response_topic()") +@pytest.mark.describe(".extract_status_code_from_response_topic()") class TestExtractStatusCodeFromDpsResponseTopic(object): - @pytest.mark.it("Returns the status code from a valid DPS response topic") + @pytest.mark.it("Returns the status code from a DPS response topic") @pytest.mark.parametrize( "topic, expected_status", [ @@ -263,11 +254,11 @@ class TestExtractStatusCodeFromDpsResponseTopic(object): ) def test_returns_status(self, topic, 
expected_status): assert ( - mqtt_topic_provisioning.extract_status_code_from_dps_response_topic(topic) + mqtt_topic_provisioning.extract_status_code_from_response_topic(topic) == expected_status ) - @pytest.mark.it("URL decodes the status code extracted from DPS response topic") + @pytest.mark.it("URL decodes the returned value") @pytest.mark.parametrize( "topic, expected_status", [ @@ -285,6 +276,25 @@ def test_returns_status(self, topic, expected_status): ) def test_url_decode(self, topic, expected_status): assert ( - mqtt_topic_provisioning.extract_status_code_from_dps_response_topic(topic) + mqtt_topic_provisioning.extract_status_code_from_response_topic(topic) == expected_status ) + + @pytest.mark.it("Raises a ValueError if the provided topic is not a DPS response topic") + @pytest.mark.parametrize( + "topic", + [ + pytest.param("not a topic", id="Not a topic"), + pytest.param( + "$iothub/twin/res/200/?$rid=d9d7ce4d-3be9-498b-abde-913b81b880e5", + id="Topic of wrong type", + ), + pytest.param( + "$dps/registrtaisons/res/200/?$rid=d9d7ce4d-3be9-498b-abde-913b81b880e5", + id="Malformed response topic", + ), + ], + ) + def test_invalid_response_topic(self, topic): + with pytest.raises(ValueError): + mqtt_topic_provisioning.extract_status_code_from_response_topic(topic) diff --git a/tests/unit/test_provisioning_mqtt_client.py b/tests/unit/test_provisioning_mqtt_client.py new file mode 100644 index 000000000..3476062bf --- /dev/null +++ b/tests/unit/test_provisioning_mqtt_client.py @@ -0,0 +1,2288 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information.
+# -------------------------------------------------------------------------- + +import asyncio +import json +import uuid + +import pytest +import ssl +import sys +import time + +import urllib +from pytest_lazyfixture import lazy_fixture +from dev_utils import custom_mock +from azure.iot.device.provisioning_mqtt_client import ( + ProvisioningMQTTClient, + DEFAULT_RECONNECT_INTERVAL, + DEFAULT_POLLING_INTERVAL, + DEFAULT_TIMEOUT_INTERVAL, +) + +from azure.iot.device import config, constant, user_agent +from azure.iot.device import exceptions as exc +from azure.iot.device import mqtt_client as mqtt +from azure.iot.device import request_response as rr +from azure.iot.device import mqtt_topic_provisioning as mqtt_topic +from azure.iot.device import sastoken as st + + +FAKE_REGISTER_REQUEST_ID = "fake_register_request_id" +FAKE_POLLING_REQUEST_ID = "fake_polling_request_id" +FAKE_REGISTRATION_ID = "fake_registration_id" +FAKE_ID_SCOPE = "fake_idscope" +FAKE_HOSTNAME = "fake.hostname" +FAKE_SIGNATURE = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=" +FAKE_EXPIRY = str(int(time.time()) + 3600) +FAKE_URI = "fake/resource/location" +FAKE_STATUS = "assigned" +FAKE_SUB_STATUS = "OK" +FAKE_OPERATION_ID = "fake_operation_id" +FAKE_DEVICE_ID = "MyDevice" +FAKE_ASSIGNED_HUB = "MyIoTHub" + +# Parametrization +# TODO: expand this when we know more about what exceptions get raised from MQTTClient +mqtt_connect_exceptions = [ + pytest.param(mqtt.MQTTConnectionFailedError(), id="MQTTConnectionFailedError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), +] +mqtt_disconnect_exceptions = [ + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception") +] +mqtt_publish_exceptions = [ + pytest.param(mqtt.MQTTError(rc=5), id="MQTTError"), + pytest.param(ValueError(), id="ValueError"), + pytest.param(TypeError(), id="TypeError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), +] +mqtt_subscribe_exceptions = [ + 
# NOTE: CancelledError is here because network failure can cancel a subscribe + # without explicit invocation of cancel on the subscribe + pytest.param(mqtt.MQTTError(rc=5), id="MQTTError"), + pytest.param(asyncio.CancelledError(), id="CancelledError (Not initiated)"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), +] +mqtt_unsubscribe_exceptions = [ + # NOTE: CancelledError is here because network failure can cancel an unsubscribe + # without explicit invocation of cancel on the unsubscribe + pytest.param(mqtt.MQTTError(rc=5), id="MQTTError"), + pytest.param(asyncio.CancelledError(), id="CancelledError (Not initiated)"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), +] + + +# Fixtures +@pytest.fixture +def sastoken(): + sastoken_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( + resource=FAKE_URI, signature=FAKE_SIGNATURE, expiry=FAKE_EXPIRY + ) + return st.SasToken(sastoken_str) + + +@pytest.fixture +def mock_sastoken_provider(mocker, sastoken): + provider = mocker.MagicMock(spec=st.SasTokenProvider) + provider.get_current_sastoken.return_value = sastoken + # Use a HangingAsyncMock so that it isn't constantly returning + provider.wait_for_new_sastoken = custom_mock.HangingAsyncMock() + provider.wait_for_new_sastoken.return_value = sastoken + # NOTE: Technically, this mock just always returns the same SasToken, + # even after an "update", but for the purposes of testing at this level, + # it doesn't matter + return provider + + +@pytest.fixture +def client_config(): + """Defaults to DPS Configuration. Required values only. 
+ Customize in test if you need specific options""" + client_config = config.ProvisioningClientConfig( + registration_id=FAKE_REGISTRATION_ID, + hostname=FAKE_HOSTNAME, + id_scope=FAKE_ID_SCOPE, + ssl_context=ssl.SSLContext(), + ) + return client_config + + +@pytest.fixture +async def client(mocker, client_config): + client = ProvisioningMQTTClient(client_config) + # Mock just the network operations from the MQTTClient, not the whole thing. + # This makes using the generators easier + client._mqtt_client.connect = mocker.AsyncMock() + client._mqtt_client.disconnect = mocker.AsyncMock() + client._mqtt_client.subscribe = mocker.AsyncMock() + client._mqtt_client.unsubscribe = mocker.AsyncMock() + client._mqtt_client.publish = mocker.AsyncMock() + # Also mock the set credentials method since we test that + client._mqtt_client.set_credentials = mocker.MagicMock() + client._mqtt_client.is_connected = mocker.MagicMock() + + # NOTE: No need to invoke .start() here, at least, not yet. + return client + + +@pytest.mark.describe("ProvisioningMQTTClient -- Instantiation") +class TestProvisioningMQTTClientInstantiation: + @pytest.mark.it("Stores the `registration_id` from the ProvisioningClientConfig as attributes") + async def test_simple_ids(self, client_config): + client = ProvisioningMQTTClient(client_config) + assert client._registration_id == client_config.registration_id + + @pytest.mark.it("Derives the `username` and stores the result as an attribute") + async def test_username( + self, + client_config, + ): + client_config.registration_id = FAKE_REGISTRATION_ID + client_config.id_scope = FAKE_ID_SCOPE + + ua = user_agent.get_provisioning_user_agent() + url_encoded_user_agent = urllib.parse.quote(ua, safe="") + # NOTE: This assertion shows the URL encoding was meaningful + assert ua != url_encoded_user_agent + expected_username = "{id_scope}/registrations/{registration_id}/api-version={api_version}&ClientVersion={user_agent}".format(
id_scope=client_config.id_scope, + registration_id=client_config.registration_id, + api_version=constant.PROVISIONING_API_VERSION, + user_agent=url_encoded_user_agent, + ) + client = ProvisioningMQTTClient(client_config) + # The expected username was derived + assert client._username == expected_username + + @pytest.mark.it( + "Stores the `sastoken_provider` from the ProvisioningClientConfig as an attribute" + ) + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="SasTokenProvider present"), + pytest.param(None, id="No SasTokenProvider present"), + ], + ) + async def test_sastoken_provider(self, client_config, sastoken_provider): + client_config.registration_id = FAKE_REGISTRATION_ID + client_config.sastoken_provider = sastoken_provider + + client = ProvisioningMQTTClient(client_config) + assert client._sastoken_provider is sastoken_provider + + @pytest.mark.it( + "Creates an MQTTClient instance based on the configuration of ProvisioningClientConfig and stores it as an attribute" + ) + @pytest.mark.parametrize( + "websockets, expected_transport, expected_port, expected_ws_path", + [ + pytest.param(True, "websockets", 443, "/$iothub/websocket", id="WebSockets"), + pytest.param(False, "tcp", 8883, None, id="TCP"), + ], + ) + async def test_mqtt_client( + self, + mocker, + client_config, + websockets, + expected_transport, + expected_port, + expected_ws_path, + ): + # Configure the client_config based on params + client_config.registration_id = FAKE_REGISTRATION_ID + client_config.websockets = websockets + + # Patch the MQTTClient constructor + mock_constructor = mocker.patch.object(mqtt, "MQTTClient", spec=mqtt.MQTTClient) + assert mock_constructor.call_count == 0 + + # Create the client under test + client = ProvisioningMQTTClient(client_config) + + # Assert that the MQTTClient was constructed as expected + assert mock_constructor.call_count == 1 + assert mock_constructor.call_args == mocker.call( + 
client_id=client_config.registration_id, + hostname=client_config.hostname, + port=expected_port, + transport=expected_transport, + keep_alive=client_config.keep_alive, + auto_reconnect=client_config.auto_reconnect, + reconnect_interval=DEFAULT_RECONNECT_INTERVAL, + ssl_context=client_config.ssl_context, + websockets_path=expected_ws_path, + proxy_options=client_config.proxy_options, + ) + assert client._mqtt_client is mock_constructor.return_value + + @pytest.mark.it("Adds incoming message filter on the MQTTClient for dps responses") + async def test_dps_response_filter(self, mocker, client_config): + client_config.registration_id = FAKE_REGISTRATION_ID + expected_topic = mqtt_topic.get_response_topic_for_subscribe() + + mocker.patch.object(mqtt, "MQTTClient", spec=mqtt.MQTTClient) + client = ProvisioningMQTTClient(client_config) + + # NOTE: Multiple filters are added, but not all are covered in this test + assert ( + mocker.call(expected_topic) + in client._mqtt_client.add_incoming_message_filter.call_args_list + ) + + @pytest.mark.it("Creates an empty RequestLedger") + async def test_request_ledger(self, client_config): + client_config.registration_id = FAKE_REGISTRATION_ID + client = ProvisioningMQTTClient(client_config) + assert isinstance(client._request_ledger, rr.RequestLedger) + assert len(client._request_ledger) == 0 + + @pytest.mark.it("Sets the _register_responses_enabled flag to False") + async def test_dps_responses_enabled(self, client_config): + client_config.registration_id = FAKE_REGISTRATION_ID + client = ProvisioningMQTTClient(client_config) + assert client._register_responses_enabled is False + + @pytest.mark.it("Sets background task attributes to None") + async def test_bg_tasks(self, client_config): + client_config.registration_id = FAKE_REGISTRATION_ID + client = ProvisioningMQTTClient(client_config) + assert client._process_dps_responses_task is None + + +@pytest.mark.describe("ProvisioningMQTTClient - .start()") +class 
TestProvisioningMQTTClientStart: + @pytest.mark.it( + "Sets the credentials on the MQTTClient, using the stored `username` as the username and no password, " + "when not using SAS authentication" + ) + async def test_mqtt_client_credentials_no_sas(self, mocker, client): + assert client._sastoken_provider is None + assert client._mqtt_client.set_credentials.call_count == 0 + + await client.start() + + assert client._mqtt_client.set_credentials.call_count == 1 + assert client._mqtt_client.set_credentials.call_args == mocker.call(client._username, None) + + # Cleanup + await client.stop() + + @pytest.mark.it( + "Sets the credentials on the MQTTClient, using the stored `username` as the username and the string-converted " + "current SasToken from the SasTokenProvider as the password, when using SAS authentication" + ) + async def test_mqtt_client_credentials_with_sas(self, mocker, client, mock_sastoken_provider): + client._sastoken_provider = mock_sastoken_provider + fake_sastoken = mock_sastoken_provider.get_current_sastoken.return_value + assert client._mqtt_client.set_credentials.call_count == 0 + + await client.start() + + assert client._mqtt_client.set_credentials.call_count == 1 + assert client._mqtt_client.set_credentials.call_args == mocker.call(client._username, str(fake_sastoken)) + + await client.stop() + + @pytest.mark.it( + "Begins running the ._process_dps_responses_task() coroutine method as a background task, storing it as an attribute" + ) + async def test_process_dps_responses_bg_task(self, client): + assert client._process_dps_responses_task is None + + await client.start() + + assert isinstance(client._process_dps_responses_task, asyncio.Task) + assert not client._process_dps_responses_task.done() + if sys.version_info > (3, 8): + # NOTE: There isn't a way to validate the contents of a task until 3.8 + # as far as I can tell.
+ task_coro = client._process_dps_responses_task.get_coro() + assert task_coro.__qualname__ == "ProvisioningMQTTClient._process_dps_responses" + + # Cleanup + await client.stop() + + +@pytest.mark.describe("ProvisioningMQTTClient - .stop()") +class TestProvisioningMQTTClientStop: + @pytest.fixture(autouse=True) + async def modify_client(self, client, mock_sastoken_provider): + client._sastoken_provider = mock_sastoken_provider + # Need to start the client so we can stop it. + await client.start() + + @pytest.mark.it("Disconnects the MQTTClient") + async def test_disconnect(self, mocker, client): + # NOTE: rather than mocking the MQTTClient, we just mock the .disconnect() method of the + # ProvisioningMQTTClient instead, since it's been fully tested elsewhere, and we assume + # correctness, lest we have to repeat all .disconnect() tests here. + original_disconnect = client.disconnect + client.disconnect = mocker.AsyncMock() + try: + assert client.disconnect.await_count == 0 + + await client.stop() + + assert client.disconnect.await_count == 1 + assert client.disconnect.await_args == mocker.call() + finally: + client.disconnect = original_disconnect + + @pytest.mark.it( + "Cancels the 'process_dps_responses' background task and removes it, if it exists" + ) + async def test_process_dps_responses_bg_task(self, client): + assert isinstance(client._process_dps_responses_task, asyncio.Task) + t = client._process_dps_responses_task + assert not t.done() + + await client.stop() + + assert t.done() + assert t.cancelled() + assert client._process_dps_responses_task is None + + # NOTE: Currently this is an invalid scenario. This shouldn't happen, but test it anyway. 
+ @pytest.mark.it("Handles the case where no 'process_dps_responses' background task exists") + async def test_process_dps_responses_bg_task_no_exist(self, client): + # The task is already running, so cancel and remove it + assert isinstance(client._process_dps_responses_task, asyncio.Task) + client._process_dps_responses_task.cancel() + client._process_dps_responses_task = None + + await client.stop() + # No AttributeError means success! + + @pytest.mark.it( + "Allows any exception raised during MQTTClient disconnect to propagate, but only after cancelling background tasks" + ) + @pytest.mark.parametrize("exception", mqtt_disconnect_exceptions) + async def test_disconnect_raises(self, mocker, client, exception): + # NOTE: rather than mocking the MQTTClient, we just mock the .disconnect() method of the + # ProvisioningMQTTClient instead, since it's been fully tested elsewhere, and we assume + # correctness, lest we have to repeat all .disconnect() tests here. + original_disconnect = client.disconnect + client.disconnect = mocker.AsyncMock(side_effect=exception) + try: + process_dps_responses_bg_task = client._process_dps_responses_task + assert not process_dps_responses_bg_task.done() + + with pytest.raises(type(exception)) as e_info: + await client.stop() + assert e_info.value is exception + + # Background tasks were also cancelled despite the exception + assert process_dps_responses_bg_task.done() + assert process_dps_responses_bg_task.cancelled() + # And they were unset too + assert client._process_dps_responses_task is None + finally: + # Unset the mock so that tests can clean up + client.disconnect = original_disconnect + + # TODO: when run by itself, this test leaves a task unresolved. Not sure why. Not too important. 
+ @pytest.mark.it( + "Does not alter any background tasks if already stopped, but does disconnect again" + ) + async def test_already_stopped(self, mocker, client): + original_disconnect = client.disconnect + client.disconnect = mocker.AsyncMock() + try: + assert client.disconnect.await_count == 0 + + # Stop + await client.stop() + assert client._process_dps_responses_task is None + assert client.disconnect.await_count == 1 + + # Stop again + await client.stop() + assert client._process_dps_responses_task is None + assert client.disconnect.await_count == 2 + + finally: + client.disconnect = original_disconnect + + # TODO: when run by itself, this test leaves a task unresolved. Not sure why. Not too important. + @pytest.mark.it( + "Can be cancelled while waiting for the MQTTClient disconnect to finish, but it won't stop background task cancellation" + ) + async def test_cancel_disconnect(self, client): + # NOTE: rather than mocking the MQTTClient, we just mock the .disconnect() method of the + # ProvisioningMQTTClient instead, since it's been fully tested elsewhere, and we assume + # correctness, lest we have to repeat all .disconnect() tests here. 
+ original_disconnect = client.disconnect + client.disconnect = custom_mock.HangingAsyncMock() + try: + + process_dps_responses_bg_task = client._process_dps_responses_task + assert not process_dps_responses_bg_task.done() + + t = asyncio.create_task(client.stop()) + + # Hanging, waiting for disconnect to finish + await client.disconnect.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + # Due to cancellation, the tasks we want to assert are done may need a moment + # to finish, since we aren't waiting on them to exit + await asyncio.sleep(0.1) + + # And yet the background tasks still were cancelled anyway + assert process_dps_responses_bg_task.done() + assert process_dps_responses_bg_task.cancelled() + # And they were unset too + assert client._process_dps_responses_task is None + finally: + # Unset the mock so that tests can clean up. + client.disconnect = original_disconnect + + @pytest.mark.it( + "Can be cancelled while waiting for the background tasks to finish cancellation, but it won't stop the background task cancellation" + ) + async def test_cancel_gather(self, mocker, client): + original_gather = asyncio.gather + asyncio.gather = custom_mock.HangingAsyncMock() + spy_register_dps_response_bg_task_cancel = mocker.spy( + client._process_dps_responses_task, "cancel" + ) + try: + process_dps_responses_bg_task = client._process_dps_responses_task + assert not process_dps_responses_bg_task.done() + + t = asyncio.create_task(client.stop()) + + # Hanging waiting for gather to return (indicating tasks are all done cancellation) + await asyncio.gather.wait_for_hang() + assert not t.done() + # Background tests may or may not have completed cancellation yet, hard to test accurately. + # But their cancellation HAS been requested. 
+ assert spy_register_dps_response_bg_task_cancel.call_count == 1 + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + # Tasks will be cancelled very soon (if they aren't already) + await asyncio.sleep(0.1) + assert process_dps_responses_bg_task.done() + assert process_dps_responses_bg_task.cancelled() + # And they were unset too + assert client._process_dps_responses_task is None + finally: + # Unset the mock so that tests can clean up. + asyncio.gather = original_gather + + +@pytest.mark.describe("ProvisioningMQTTClient - .connect()") +class TestProvisioningMQTTClientConnect: + @pytest.mark.it("Awaits a connect using the MQTTClient") + async def test_mqtt_connect(self, mocker, client): + assert client._mqtt_client.connect.await_count == 0 + + await client.connect() + + assert client._mqtt_client.connect.await_count == 1 + assert client._mqtt_client.connect.await_args == mocker.call() + + @pytest.mark.it("Allows any exceptions raised during the MQTTClient connect to propagate") + @pytest.mark.parametrize("exception", mqtt_connect_exceptions) + async def test_mqtt_exception(self, client, exception): + client._mqtt_client.connect.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await client.connect() + assert e_info.value is exception + + @pytest.mark.it("Can be cancelled while waiting for the MQTTClient connect to finish") + async def test_cancel(self, client): + client._mqtt_client.connect = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(client.connect()) + + # Hanging, waiting for MQTT connect to finish + await client._mqtt_client.connect.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe("ProvisioningMQTTClient - .disconnect()") +class TestProvisioningMQTTClientDisconnect: + @pytest.mark.it("Awaits a disconnect using the MQTTClient") + async def test_mqtt_disconnect(self, mocker, client): + 
assert client._mqtt_client.disconnect.await_count == 0 + + await client.disconnect() + + assert client._mqtt_client.disconnect.await_count == 1 + assert client._mqtt_client.disconnect.await_args == mocker.call() + + @pytest.mark.it("Allows any exceptions raised during the MQTTClient disconnect to propagate") + @pytest.mark.parametrize("exception", mqtt_disconnect_exceptions) + async def test_mqtt_exception(self, client, exception): + client._mqtt_client.disconnect.side_effect = exception + try: + with pytest.raises(type(exception)) as e_info: + await client.disconnect() + assert e_info.value is exception + finally: + # Unset the side effect for cleanup (since shutdown uses disconnect) + client._mqtt_client.disconnect.side_effect = None + + @pytest.mark.it("Can be cancelled while waiting for the MQTTClient disconnect to finish") + async def test_cancel(self, mocker, client): + client._mqtt_client.disconnect = custom_mock.HangingAsyncMock() + try: + t = asyncio.create_task(client.disconnect()) + + # Hanging, waiting for MQTT disconnect to finish + await client._mqtt_client.disconnect.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + finally: + # Unset the HangingMock for cleanup (since shutdown uses disconnect) + client._mqtt_client.disconnect = mocker.AsyncMock() + + +@pytest.mark.describe("ProvisioningMQTTClient - .send_register()") +class TestProvisioningMQTTClientSendRegister: + @pytest.fixture(autouse=True) + def modify_publish(self, client): + # Add a side effect to publish that will complete the pending request for that request id. + # This will allow most tests to be able to ignore request/response infrastructure mocks. + # If this is not the desired behavior (such as in tests OF the request/response paradigm) + # override the publish behavior.
+ # + # To see tests regarding how this actually works in practice, see the relevant test suite + async def fake_publish(topic, payload): + rid = topic[topic.rfind("$rid=") :].split("=")[1] + response_body_dict = { + "operationId": FAKE_OPERATION_ID, + "status": FAKE_STATUS, + "registrationState": {"registrationId": FAKE_REGISTRATION_ID}, + } + response = rr.Response(rid, 200, json.dumps(response_body_dict)) + await client._request_ledger.match_response(response) + + client._mqtt_client.publish.side_effect = fake_publish + + @pytest.fixture(params=["With Payload", "No Payload"]) + def registration_payload(self, request): + if request.param == "With Payload": + return "some payload" + else: + return None + + @pytest.mark.it( + "Awaits a subscribe to the register dps response topic using the MQTTClient, if register dps responses have not already been enabled" + ) + async def test_mqtt_subscribe_not_enabled(self, mocker, client, registration_payload): + assert client._mqtt_client.subscribe.await_count == 0 + assert client._register_responses_enabled is False + expected_topic = mqtt_topic.get_response_topic_for_subscribe() + + await client.send_register(payload=registration_payload) + + assert client._mqtt_client.subscribe.await_count == 1 + assert client._mqtt_client.subscribe.await_args == mocker.call(expected_topic) + + @pytest.mark.it( + "Does not perform a subscribe if register dps responses have already been enabled" + ) + async def test_mqtt_subscribe_already_enabled(self, client, registration_payload): + assert client._mqtt_client.subscribe.await_count == 0 + client._register_responses_enabled = True + + await client.send_register(payload=registration_payload) + + assert client._mqtt_client.subscribe.call_count == 0 + + @pytest.mark.it("Sets the register_responses_enabled flag to True upon subscribe success") + async def test_response_enabled_flag_success(self, client, registration_payload): + assert client._register_responses_enabled is False + + await 
client.send_register(payload=registration_payload) + + assert client._register_responses_enabled is True + + @pytest.mark.it("Generates a new Request, using the RequestLedger stored on the client") + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Register Responses Already Enabled"), + pytest.param(False, id="Register Responses Not Yet Enabled"), + ], + ) + async def test_generate_request(self, mocker, client, responses_enabled, registration_payload): + client._register_responses_enabled = responses_enabled + spy_create_request = mocker.spy(client._request_ledger, "create_request") + + # Mock the uuid call as well to return fake request id + mocker.patch.object(uuid, "uuid4", return_value=FAKE_REGISTER_REQUEST_ID) + + await client.send_register(payload=registration_payload) + + assert spy_create_request.await_count == 1 + assert spy_create_request.await_args == mocker.call(FAKE_REGISTER_REQUEST_ID) + + @pytest.mark.it( + "Awaits a publish to the register request topic using the MQTTClient, sending a JSON object containing the client's registration ID and the provided registration payload" + ) + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Register Responses Already Enabled"), + pytest.param(False, id="Register Responses Not Yet Enabled"), + ], + ) + async def test_mqtt_publish(self, mocker, client, responses_enabled, registration_payload): + client._register_responses_enabled = responses_enabled + spy_create_request = mocker.spy(client._request_ledger, "create_request") + assert client._mqtt_client.publish.await_count == 0 + + await client.send_register(payload=registration_payload) + + request = spy_create_request.spy_return + expected_topic = mqtt_topic.get_register_topic_for_publish(request.request_id) + payload_dict = {"registrationId": client._registration_id, "payload": registration_payload} + expected_payload = json.dumps(payload_dict) + + assert client._mqtt_client.publish.await_count == 1 + 
assert client._mqtt_client.publish.await_args == mocker.call( + expected_topic, expected_payload + ) + + @pytest.mark.it("Supports any JSON-serializable payload") + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Register Responses Already Enabled"), + pytest.param(False, id="Register Responses Not Yet Enabled"), + ], + ) + @pytest.mark.parametrize( + "registration_payload", + [ + pytest.param("String Payload", id="String Payload"), + pytest.param(1234, id="Int Payload"), + pytest.param(2.0, id="Float Payload"), + pytest.param(True, id="Boolean Payload"), + pytest.param([1, 2, 3], id="List Payload"), + pytest.param({"some": {"dictionary": "value"}}, id="Dictionary Payload"), + pytest.param((1, 2), id="Tuple Payload"), + pytest.param(None, id="No Payload"), + ], + ) + async def test_json_payload(self, mocker, client, responses_enabled, registration_payload): + client._register_responses_enabled = responses_enabled + spy_create_request = mocker.spy(client._request_ledger, "create_request") + assert client._mqtt_client.publish.await_count == 0 + + await client.send_register(payload=registration_payload) + + request = spy_create_request.spy_return + expected_topic = mqtt_topic.get_register_topic_for_publish(request.request_id) + payload_dict = {"registrationId": client._registration_id, "payload": registration_payload} + expected_payload = json.dumps(payload_dict) + + assert client._mqtt_client.publish.await_count == 1 + assert client._mqtt_client.publish.await_args == mocker.call( + expected_topic, expected_payload + ) + + @pytest.mark.it("Awaits a received Response to the Request") + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Register Responses Already Enabled"), + pytest.param(False, id="Register Responses Not Yet Enabled"), + ], + ) + async def test_get_response(self, mocker, client, responses_enabled, registration_payload): + client._register_responses_enabled = responses_enabled + # Override 
autocompletion behavior on publish (we don't want it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = FAKE_REGISTER_REQUEST_ID # Need this for string manipulation + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + # Mock out the request to return a response + mock_response = mocker.MagicMock(spec=rr.Response) + mock_response.status = 200 + response_body_dict = { + "operationId": FAKE_OPERATION_ID, + "status": FAKE_STATUS, + "registrationState": {"registrationId": FAKE_REGISTRATION_ID}, + } + mock_response.body = json.dumps(response_body_dict) + mock_request.get_response.return_value = mock_response + + await client.send_register(payload=registration_payload) + + assert mock_request.get_response.await_count == 1 + assert mock_request.get_response.await_args == mocker.call() + + @pytest.mark.it( + "Raises an ProvisioningServiceError if an unsuccessful status (300-429) is received via the Response" + ) + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Register Responses Already Enabled"), + pytest.param(False, id="Register Responses Not Yet Enabled"), + ], + ) + @pytest.mark.parametrize( + "failed_status", + [ + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status Code: 400"), + pytest.param(428, id="Status Code: 428"), + ], + ) + async def test_failed_response( + self, mocker, client, responses_enabled, failed_status, registration_payload + ): + client._register_responses_enabled = responses_enabled + # Override autocompletion behavior on publish (we don't need it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = FAKE_REGISTER_REQUEST_ID # Need this for string manipulation + 
mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + # Mock out the request to return a response + mock_response = mocker.MagicMock(spec=rr.Response) + mock_response.status = failed_status + mock_response.body = " " + mock_request.get_response.return_value = mock_response + + with pytest.raises(exc.ProvisioningServiceError): + await client.send_register(payload=registration_payload) + + @pytest.mark.it( + "Returns the registration result received in the Response, converted to JSON, if the Response status was successful" + ) + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Register Responses Already Enabled"), + pytest.param(False, id="Register Responses Not Yet Enabled"), + ], + ) + async def test_success_response(self, mocker, client, responses_enabled, registration_payload): + client._register_responses_enabled = responses_enabled + # Override autocompletion behavior on publish (we don't need it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = FAKE_REGISTER_REQUEST_ID # Need this for string manipulation + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + # Mock out the request to return a response + mock_response = mocker.MagicMock(spec=rr.Response) + mock_response.status = 200 + response_body_dict = { + "operationId": FAKE_OPERATION_ID, + "status": FAKE_STATUS, + "registrationState": { + "assignedHub": FAKE_ASSIGNED_HUB, + "createdDateTimeUtc": None, + "deviceId": FAKE_DEVICE_ID, + "etag": None, + "lastUpdatedDateTimeUtc": None, + "payload": None, + "subStatus": FAKE_SUB_STATUS, + }, + } + + fake_response_string = json.dumps(response_body_dict) + mock_response.body = fake_response_string + mock_request.get_response.return_value = mock_response + + registration_response = await 
client.send_register(payload=registration_payload) + assert registration_response == json.loads(fake_response_string) + + @pytest.mark.it( + "Calls the send_register method thrice after different interval retry after values and " + "then finally returns the registration result received in the Response, converted to JSON, " + "when the Response status was successful on the last attempt" + ) + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Register Responses Already Enabled"), + pytest.param(False, id="Register Responses Not Yet Enabled"), + ], + ) + async def test_retry_response(self, mocker, client, responses_enabled, registration_payload): + # Mock out the sleep so that these tests don't run so slowly + mock_sleep = mocker.patch.object(asyncio, "sleep") + + retry_after_val_1 = 1 + retry_after_val_2 = 2 + client._register_responses_enabled = responses_enabled + # Override autocompletion behavior on publish (we don't need it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = FAKE_REGISTER_REQUEST_ID # Need this for string manipulation + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + # Mock out the request to return 3 different responses + mock_response_1 = mocker.MagicMock(spec=rr.Response) + mock_response_1.status = 429 + mock_response_1.body = " " + mock_response_1.properties = {"retry-after": str(retry_after_val_1)} + + mock_response_2 = mocker.MagicMock(spec=rr.Response) + mock_response_2.status = 429 + mock_response_2.body = " " + mock_response_2.properties = {"retry-after": str(retry_after_val_2)} + mock_response_3 = mocker.MagicMock(spec=rr.Response) + mock_response_3.status = 200 + response_body_dict = { + "operationId": FAKE_OPERATION_ID, + "status": FAKE_STATUS, + "registrationState": { + "assignedHub": FAKE_ASSIGNED_HUB, + "createdDateTimeUtc": 
None, + "deviceId": FAKE_DEVICE_ID, + "etag": None, + "lastUpdatedDateTimeUtc": None, + "payload": None, + "subStatus": FAKE_SUB_STATUS, + }, + } + fake_response_string = json.dumps(response_body_dict) + mock_response_3.body = fake_response_string + + mock_request.get_response.side_effect = [mock_response_1, mock_response_2, mock_response_3] + mocker.patch.object(uuid, "uuid4", return_value=mock_request.request_id) + + expected_topic_register = mqtt_topic.get_register_topic_for_publish(mock_request.request_id) + payload_dict = {"registrationId": FAKE_REGISTRATION_ID, "payload": registration_payload} + expected_registration_payload = json.dumps(payload_dict) + + registration_response = await client.send_register(payload=registration_payload) + + assert mock_sleep.call_count == 3 + + # First sleep is 0 and then retry after + assert mock_sleep.call_args_list == [ + mocker.call(0), + mocker.call(retry_after_val_1), + mocker.call(retry_after_val_2), + ] + + assert client._mqtt_client.publish.await_count == 3 + # all publish calls happen with same topic and same payload + assert client._mqtt_client.publish.await_args_list == [ + mocker.call(expected_topic_register, expected_registration_payload), + mocker.call(expected_topic_register, expected_registration_payload), + mocker.call(expected_topic_register, expected_registration_payload), + ] + assert registration_response == json.loads(fake_response_string) + + @pytest.mark.it( + "Calls the send_register on the register topic method and then the send_polling on a different topic after a " + "polling interval of 2 secs and then finally returns the registration result received in the Response, " + "converted to JSON, when the Response status was successful on the last attempt" + ) + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Register Responses Already Enabled"), + pytest.param(False, id="Register Responses Not Yet Enabled"), + ], + ) + async def test_polling_response(self, mocker, client,
responses_enabled, registration_payload): + # Mock out the sleep so that these tests don't run so slowly + mock_sleep = mocker.patch.object(asyncio, "sleep") + + client._register_responses_enabled = responses_enabled + # Override autocompletion behavior on publish (we don't need it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = FAKE_REGISTER_REQUEST_ID # Need this for string manipulation + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + + # Mock out the request to return 2 different responses + mock_response_1 = mocker.MagicMock(spec=rr.Response) + mock_response_1.status = 202 + response_body_dict = {"operationId": FAKE_OPERATION_ID, "status": "assigning"} + fake_response_string = json.dumps(response_body_dict) + mock_response_1.body = fake_response_string + + mock_response_2 = mocker.MagicMock(spec=rr.Response) + mock_response_2.status = 200 + response_body_dict = { + "operationId": FAKE_OPERATION_ID, + "status": FAKE_STATUS, + "registrationState": { + "assignedHub": FAKE_ASSIGNED_HUB, + "createdDateTimeUtc": None, + "deviceId": FAKE_DEVICE_ID, + "etag": None, + "lastUpdatedDateTimeUtc": None, + "payload": None, + "subStatus": FAKE_SUB_STATUS, + }, + } + fake_response_string = json.dumps(response_body_dict) + mock_response_2.body = fake_response_string + + mock_request.get_response.side_effect = [mock_response_1, mock_response_2] + + # Mock uuid4 to return the fake request id + mocker.patch.object(uuid, "uuid4", return_value=mock_request.request_id) + + expected_topic_register = mqtt_topic.get_register_topic_for_publish(mock_request.request_id) + expected_topic_query = mqtt_topic.get_status_query_topic_for_publish( + mock_request.request_id, FAKE_OPERATION_ID + ) + payload_dict = {"registrationId": FAKE_REGISTRATION_ID, "payload": registration_payload} + expected_registration_payload 
= json.dumps(payload_dict) + + registration_response = await client.send_register(payload=registration_payload) + + assert mock_sleep.call_count == 2 + + # First sleep is 0 and then DEFAULT_POLLING_INTERVAL + assert mock_sleep.call_args_list == [ + mocker.call(0), + mocker.call(DEFAULT_POLLING_INTERVAL), + ] + + assert client._mqtt_client.publish.await_count == 2 + assert client._mqtt_client.publish.await_args_list == [ + mocker.call(expected_topic_register, expected_registration_payload), + mocker.call(expected_topic_query, " "), + ] + assert registration_response == json.loads(fake_response_string) + + # NOTE: MQTTClient subscribe can generate its own cancellations due to network failure. + # This is different from a user-initiated cancellation + @pytest.mark.it("Allows any exceptions raised from the MQTTClient subscribe to propagate") + @pytest.mark.parametrize("exception", mqtt_subscribe_exceptions) + async def test_mqtt_subscribe_exception(self, client, exception, registration_payload): + assert client._register_responses_enabled is False + client._mqtt_client.subscribe.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await client.send_register(payload=registration_payload) + assert e_info.value is exception + + # NOTE: MQTTClient subscribe can generate its own cancellations due to network failure.
+ # This is different from a user-initiated cancellation + @pytest.mark.it( + "Does not set the register_responses_enabled flag to True or create a Request if MQTTClient subscribe raises" + ) + @pytest.mark.parametrize("exception", mqtt_subscribe_exceptions) + async def test_subscribe_exception_cleanup( + self, mocker, client, exception, registration_payload + ): + assert client._register_responses_enabled is False + spy_create_request = mocker.spy(client._request_ledger, "create_request") + client._mqtt_client.subscribe.side_effect = exception + + with pytest.raises(type(exception)): + await client.send_register(payload=registration_payload) + + assert client._register_responses_enabled is False + assert spy_create_request.await_count == 0 + + # NOTE: This is a user invoked cancel, as opposed to one above, which was generated by the + # MQTTClient in response to a network failure. + @pytest.mark.it( + "Does not set the register_responses_enabled flag to True or create a Request if cancelled while waiting for the MQTTClient subscribe to finish" + ) + async def test_mqtt_subscribe_cancel_cleanup(self, mocker, client, registration_payload): + assert client._register_responses_enabled is False + spy_create_request = mocker.spy(client._request_ledger, "create_request") + client._mqtt_client.subscribe = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(client.send_register(payload=registration_payload)) + + # Hanging, waiting for MQTT publish to finish + await client._mqtt_client.subscribe.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + assert client._register_responses_enabled is False + assert spy_create_request.await_count == 0 + + @pytest.mark.it("Allows any exceptions raised from the MQTTClient publish to propagate") + @pytest.mark.parametrize("exception", mqtt_publish_exceptions) + async def test_mqtt_publish_exception(self, client, exception, registration_payload): + 
client._mqtt_client.publish.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await client.send_register(payload=registration_payload) + assert e_info.value is exception + + @pytest.mark.it("Deletes the Request from the RequestLedger if MQTTClient publish raises") + @pytest.mark.parametrize("exception", mqtt_publish_exceptions) + async def test_mqtt_publish_exception_cleanup( + self, mocker, client, exception, registration_payload + ): + client._mqtt_client.publish.side_effect = exception + spy_create_request = mocker.spy(client._request_ledger, "create_request") + spy_delete_request = mocker.spy(client._request_ledger, "delete_request") + # Mock the uuid call as well to return fake request id + mocker.patch.object(uuid, "uuid4", return_value=FAKE_REGISTER_REQUEST_ID) + + with pytest.raises(type(exception)): + await client.send_register(payload=registration_payload) + + assert spy_create_request.await_count == 1 + assert spy_create_request.await_args == mocker.call(FAKE_REGISTER_REQUEST_ID) + assert spy_delete_request.await_count == 1 + assert spy_delete_request.await_args == mocker.call( + spy_create_request.spy_return.request_id + ) + + @pytest.mark.it( + "Deletes the Request from the RequestLedger if cancelled while waiting for the MQTTClient publish to finish" + ) + async def test_mqtt_publish_cancel_cleanup(self, mocker, client, registration_payload): + client._mqtt_client.publish = custom_mock.HangingAsyncMock() + spy_create_request = mocker.spy(client._request_ledger, "create_request") + spy_delete_request = mocker.spy(client._request_ledger, "delete_request") + # Mock the uuid call as well to return fake request id + mocker.patch.object(uuid, "uuid4", return_value=FAKE_REGISTER_REQUEST_ID) + + t = asyncio.create_task(client.send_register(payload=registration_payload)) + + # Hanging, waiting for MQTT publish to finish + await client._mqtt_client.publish.wait_for_hang() + assert not t.done() + + # Request was created, but not yet 
deleted + assert spy_create_request.await_count == 1 + assert spy_create_request.await_args == mocker.call(FAKE_REGISTER_REQUEST_ID) + assert spy_delete_request.await_count == 0 + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + # The Request that was created has now been deleted + assert spy_delete_request.await_count == 1 + assert spy_delete_request.await_args == mocker.call( + spy_create_request.spy_return.request_id + ) + + @pytest.mark.it( + "Deletes the Request from the RequestLedger if cancelled while waiting for a register dps response" + ) + async def test_waiting_response_cancel_cleanup(self, mocker, client, registration_payload): + # Override autocompletion behavior on publish (we don't want it here) + client._mqtt_client.publish = mocker.AsyncMock() + + # Mock Request creation to return a specific, mocked request that hangs on + # awaiting a Response + request = rr.Request("fake_request_id") + request.get_response = custom_mock.HangingAsyncMock() + mocker.patch.object(rr, "Request", return_value=request) + spy_create_request = mocker.spy(client._request_ledger, "create_request") + spy_delete_request = mocker.spy(client._request_ledger, "delete_request") + + # Mock the uuid call as well to return fake request id + mocker.patch.object(uuid, "uuid4", return_value=FAKE_REGISTER_REQUEST_ID) + + send_task = asyncio.create_task(client.send_register(payload=registration_payload)) + + # Hanging, waiting for response + await request.get_response.wait_for_hang() + assert not send_task.done() + + # Request was created, but not yet deleted + assert spy_create_request.await_count == 1 + assert spy_create_request.await_args == mocker.call(FAKE_REGISTER_REQUEST_ID) + assert spy_delete_request.await_count == 0 + + # Cancel + send_task.cancel() + with pytest.raises(asyncio.CancelledError): + await send_task + + # The Request that was created was also deleted + assert spy_delete_request.await_count == 1 + assert 
spy_delete_request.await_args == mocker.call(request.request_id) + + @pytest.mark.it("Waits to retrieve response within a default period of time") + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Register Responses Already Enabled"), + pytest.param(False, id="Register Responses Not Yet Enabled"), + ], + ) + async def test_waiting_response_within_default_timeout( + self, mocker, client, responses_enabled, registration_payload + ): + client._register_responses_enabled = responses_enabled + # Override autocompletion behavior on publish (we don't need it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = FAKE_REGISTER_REQUEST_ID # Need this for string manipulation + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + + # Mock out the response to return completely successful response + mock_response = mocker.MagicMock(spec=rr.Response) + mock_response.status = 200 + response_body_dict = { + "operationId": FAKE_OPERATION_ID, + "status": FAKE_STATUS, + "registrationState": { + "assignedHub": FAKE_ASSIGNED_HUB, + "createdDateTimeUtc": None, + "deviceId": FAKE_DEVICE_ID, + "etag": None, + "lastUpdatedDateTimeUtc": None, + "payload": None, + "subStatus": FAKE_SUB_STATUS, + }, + } + fake_response_string = json.dumps(response_body_dict) + mock_response.body = fake_response_string + + mock_request.get_response.return_value = mock_response + + spy_wait_for = mocker.spy(asyncio, "wait_for") + + await client.send_register(payload=registration_payload) + + assert spy_wait_for.await_count == 1 + assert asyncio.iscoroutine(spy_wait_for.await_args.args[0]) + assert spy_wait_for.await_args.args[1] == DEFAULT_TIMEOUT_INTERVAL + + @pytest.mark.it( + "Raises ProvisioningServiceError from TimeoutError if the response is not retrieved within a default period" + ) + 
@pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Register Responses Already Enabled"), + pytest.param(False, id="Register Responses Not Yet Enabled"), + ], + ) + async def test_waiting_response_timeout_exception( + self, mocker, client, responses_enabled, registration_payload + ): + client._register_responses_enabled = responses_enabled + # Override autocompletion behavior on publish (we don't need it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = FAKE_REGISTER_REQUEST_ID # Need this for string manipulation + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + + # Simulate a timeout + fake_timeout = asyncio.TimeoutError("fake timeout exception") + + def simulate_timeout(awaitable, timeout): + # Create a task from the awaitable, and then cancel it so that no warnings are thrown + # about the awaitable never being awaited. 
+ # This is because in the real implementation, the awaitable is scheduled on the + # event loop, and then later cancelled in the case of timeout + t = asyncio.create_task(awaitable) + t.cancel() + raise fake_timeout + + mocker.patch.object(asyncio, "wait_for", side_effect=simulate_timeout) + + # Timeout causes ProvisioningServiceError + with pytest.raises(exc.ProvisioningServiceError) as e_info: + await client.send_register(payload=registration_payload) + + assert e_info.value.__cause__ is fake_timeout + + @pytest.mark.it( + "Deletes the Request from the RequestLedger if timeout occurred while waiting for a register dps response" + ) + @pytest.mark.parametrize( + "responses_enabled", + [ + pytest.param(True, id="Register Responses Already Enabled"), + pytest.param(False, id="Register Responses Not Yet Enabled"), + ], + ) + async def test_waiting_response_timeout_cleanup( + self, mocker, client, responses_enabled, registration_payload + ): + client._register_responses_enabled = responses_enabled + # Override autocompletion behavior on publish (we don't need it here) + client._mqtt_client.publish = mocker.AsyncMock() + + # Mock the uuid call as well to return fake request id + mocker.patch.object(uuid, "uuid4", return_value=FAKE_REGISTER_REQUEST_ID) + + spy_create_request = mocker.spy(client._request_ledger, "create_request") + spy_delete_request = mocker.spy(client._request_ledger, "delete_request") + + # Simulate a timeout + fake_timeout = asyncio.TimeoutError("fake timeout exception") + + def simulate_timeout(awaitable, timeout): + # Create a task from the awaitable, and then cancel it so that no warnings are thrown + # about the awaitable never being awaited. 
+ # This is because in the real implementation, the awaitable is scheduled on the + # event loop, and then later cancelled in the case of timeout + t = asyncio.create_task(awaitable) + t.cancel() + raise fake_timeout + + mocker.patch.object(asyncio, "wait_for", side_effect=simulate_timeout) + + # Timeout occurs + with pytest.raises(exc.ProvisioningServiceError) as e_info: + await client.send_register(payload=registration_payload) + assert e_info.value.__cause__ is fake_timeout + + # Request was created, but not yet deleted + assert spy_create_request.await_count == 1 + + # The Request that was created was also deleted + assert spy_delete_request.await_count == 1 + assert spy_delete_request.await_args == mocker.call(FAKE_REGISTER_REQUEST_ID) + + +@pytest.mark.describe("ProvisioningMQTTClient - .send_polling()") +class TestProvisioningMQTTClientSendPolling: + @pytest.fixture(autouse=True) + def modify_publish(self, client): + # Add a side effect to publish that will complete the pending request for that request id. + # This will allow most tests to be able to ignore request/response infrastructure mocks. + # If this is not the desired behavior (such as in tests OF the request/response paradigm) + # override the publish behavior. + # + # To see tests regarding how this actually works in practice, see the relevant test suite + async def fake_publish(topic, payload): + rid = topic.split("&")[0].split("=")[1] + response_body_dict = { + "operationId": FAKE_OPERATION_ID, + "status": FAKE_STATUS, + "registrationState": {"registrationId": FAKE_REGISTRATION_ID}, + } + response = rr.Response(rid, 200, json.dumps(response_body_dict)) + await client._request_ledger.match_response(response) + + client._mqtt_client.publish.side_effect = fake_publish + + @pytest.fixture(autouse=True) + def mock_sleep(self, mocker): + # .send_polling() always involves asyncio sleeps, which dramatically slows down + # the unit tests. Mock out these sleeps to speed them up. 
+ return mocker.patch.object(asyncio, "sleep") + + @pytest.mark.it("Generates a new Request, using the RequestLedger stored on the client") + async def test_generate_request(self, mocker, client): + spy_create_request = mocker.spy(client._request_ledger, "create_request") + + # Mock the uuid call as well to return fake request id + mocker.patch.object(uuid, "uuid4", return_value=FAKE_POLLING_REQUEST_ID) + + await client.send_polling(operation_id=FAKE_OPERATION_ID) + + assert spy_create_request.await_count == 1 + assert spy_create_request.await_args == mocker.call(FAKE_POLLING_REQUEST_ID) + + @pytest.mark.it("Awaits a publish to the polling request topic using the MQTTClient") + async def test_mqtt_publish(self, mocker, client): + # Mock the uuid call as well to return fake request id + mocker.patch.object(uuid, "uuid4", return_value=FAKE_POLLING_REQUEST_ID) + + spy_create_request = mocker.spy(client._request_ledger, "create_request") + + assert client._mqtt_client.publish.await_count == 0 + + await client.send_polling(operation_id=FAKE_OPERATION_ID) + + request = spy_create_request.spy_return + expected_topic = mqtt_topic.get_status_query_topic_for_publish( + request.request_id, FAKE_OPERATION_ID + ) + + assert client._mqtt_client.publish.await_count == 1 + assert client._mqtt_client.publish.await_args == mocker.call(expected_topic, " ") + + @pytest.mark.it("Awaits a received Response to the Request") + async def test_get_response(self, mocker, client): + # Override autocompletion behavior on publish (we don't want it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = FAKE_POLLING_REQUEST_ID # Need this for string manipulation + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + # Mock out the request to return a response + mock_response = mocker.MagicMock(spec=rr.Response) + 
mock_response.status = 200 + response_body_dict = { + "operationId": FAKE_OPERATION_ID, + "status": FAKE_STATUS, + "registrationState": {"registrationId": FAKE_REGISTRATION_ID}, + } + mock_response.body = json.dumps(response_body_dict) + mock_request.get_response.return_value = mock_response + + await client.send_polling(FAKE_OPERATION_ID) + + assert mock_request.get_response.await_count == 1 + assert mock_request.get_response.await_args == mocker.call() + + @pytest.mark.it( + "Raises an ProvisioningServiceError if an unsuccessful status (300-429) is received via the Response" + ) + @pytest.mark.parametrize( + "failed_status", + [ + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status Code: 400"), + pytest.param(428, id="Status Code: 428"), + ], + ) + async def test_failed_response(self, mocker, client, failed_status): + # Override autocompletion behavior on publish (we don't need it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = FAKE_POLLING_REQUEST_ID # Need this for string manipulation + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + # Mock out the request to return a response + mock_response = mocker.MagicMock(spec=rr.Response) + mock_response.status = failed_status + mock_response.body = " " + mock_request.get_response.return_value = mock_response + + with pytest.raises(exc.ProvisioningServiceError): + await client.send_polling(FAKE_OPERATION_ID) + + @pytest.mark.it( + "Returns the registration result received in the Response, converted to JSON, if the Response status was successful" + ) + async def test_success_response(self, mocker, client): + # Override autocompletion behavior on publish (we don't need it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = 
mocker.MagicMock(spec=rr.Request) + mock_request.request_id = FAKE_POLLING_REQUEST_ID # Need this for string manipulation + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + # Mock out the request to return a response + mock_response = mocker.MagicMock(spec=rr.Response) + mock_response.status = 200 + response_body_dict = { + "operationId": FAKE_OPERATION_ID, + "status": FAKE_STATUS, + "registrationState": { + "assignedHub": FAKE_ASSIGNED_HUB, + "createdDateTimeUtc": None, + "deviceId": FAKE_DEVICE_ID, + "etag": None, + "lastUpdatedDateTimeUtc": None, + "payload": None, + "subStatus": FAKE_SUB_STATUS, + }, + } + + fake_response_string = json.dumps(response_body_dict) + mock_response.body = fake_response_string + mock_request.get_response.return_value = mock_response + + registration_response = await client.send_polling(FAKE_OPERATION_ID) + assert registration_response == json.loads(fake_response_string) + + @pytest.mark.it( + "Calls the send_polling method thrice with different interval and retry after values and " + "then finally returns the registration result received in the Response, converted to JSON, " + "when the Response status was successful on the last attempt" + ) + async def test_retry_response(self, mocker, client, mock_sleep): + retry_after_val_1 = 1 + retry_after_val_2 = 2 + # Override autocompletion behavior on publish (we don't need it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = FAKE_POLLING_REQUEST_ID # Need this for string manipulation + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + # Mock out the request to return 3 different responses + mock_response_1 = mocker.MagicMock(spec=rr.Response) + mock_response_1.status = 429 + mock_response_1.body = " " + mock_response_1.properties = {"retry-after": str(retry_after_val_1)} + + 
mock_response_2 = mocker.MagicMock(spec=rr.Response) + mock_response_2.status = 429 + mock_response_2.body = " " + mock_response_2.properties = {"retry-after": str(retry_after_val_2)} + mock_response_3 = mocker.MagicMock(spec=rr.Response) + mock_response_3.status = 200 + response_body_dict = { + "operationId": FAKE_OPERATION_ID, + "status": FAKE_STATUS, + "registrationState": { + "assignedHub": FAKE_ASSIGNED_HUB, + "createdDateTimeUtc": None, + "deviceId": FAKE_DEVICE_ID, + "etag": None, + "lastUpdatedDateTimeUtc": None, + "payload": None, + "subStatus": FAKE_SUB_STATUS, + }, + } + fake_response_string = json.dumps(response_body_dict) + mock_response_3.body = fake_response_string + + mock_request.get_response.side_effect = [mock_response_1, mock_response_2, mock_response_3] + mocker.patch.object(uuid, "uuid4", return_value=mock_request.request_id) + + expected_topic_query = mqtt_topic.get_status_query_topic_for_publish( + mock_request.request_id, FAKE_OPERATION_ID + ) + + registration_response = await client.send_polling(FAKE_OPERATION_ID) + + assert mock_sleep.call_count == 3 + + # First sleep is default polling interval and then retry after + assert mock_sleep.call_args_list == [ + mocker.call(2), + mocker.call(retry_after_val_1), + mocker.call(retry_after_val_2), + ] + + assert client._mqtt_client.publish.await_count == 3 + # all publish calls happen with same topic nad same payload + assert client._mqtt_client.publish.await_args_list == [ + mocker.call(expected_topic_query, " "), + mocker.call(expected_topic_query, " "), + mocker.call(expected_topic_query, " "), + ] + assert registration_response == json.loads(fake_response_string) + + @pytest.mark.it( + "Calls the send_polling on the query topic 2 times after a polling interval of 2 secs and then " + "finally returns the registration result received in the Response, " + "converted to JSON, when the Response status was successful on the last attempt" + ) + async def test_polling_response(self, mocker, client, 
mock_sleep): + # Override autocompletion behavior on publish (we don't need it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = FAKE_POLLING_REQUEST_ID # Need this for string manipulation + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + + # Mock out the request to return 2 different responses + mock_response_1 = mocker.MagicMock(spec=rr.Response) + mock_response_1.status = 202 + response_body_dict = {"operationId": FAKE_OPERATION_ID, "status": "assigning"} + fake_response_string = json.dumps(response_body_dict) + mock_response_1.body = fake_response_string + # EMoty dict with no retry after + mock_response_1.properties = {} + + mock_response_2 = mocker.MagicMock(spec=rr.Response) + mock_response_2.status = 200 + response_body_dict = { + "operationId": FAKE_OPERATION_ID, + "status": FAKE_STATUS, + "registrationState": { + "assignedHub": FAKE_ASSIGNED_HUB, + "createdDateTimeUtc": None, + "deviceId": FAKE_DEVICE_ID, + "etag": None, + "lastUpdatedDateTimeUtc": None, + "payload": None, + "subStatus": FAKE_SUB_STATUS, + }, + } + fake_response_string = json.dumps(response_body_dict) + mock_response_2.body = fake_response_string + + # Need to use side effect instead of return value to generate different responses + mock_request.get_response.side_effect = [mock_response_1, mock_response_2] + + # Mock uuid4 to return the fake request id + mocker.patch.object(uuid, "uuid4", return_value=mock_request.request_id) + + expected_topic_query = mqtt_topic.get_status_query_topic_for_publish( + mock_request.request_id, FAKE_OPERATION_ID + ) + + registration_response = await client.send_polling(FAKE_OPERATION_ID) + + assert mock_sleep.call_count == 2 + + # Both sleep are DEFAULT_POLLING_INTERVAL + assert mock_sleep.call_args_list == [ + mocker.call(DEFAULT_POLLING_INTERVAL), + 
mocker.call(DEFAULT_POLLING_INTERVAL), + ] + + assert client._mqtt_client.publish.await_count == 2 + assert client._mqtt_client.publish.await_args_list == [ + mocker.call(expected_topic_query, " "), + mocker.call(expected_topic_query, " "), + ] + assert registration_response == json.loads(fake_response_string) + + @pytest.mark.it("Allows any exceptions raised from the MQTTClient publish to propagate") + @pytest.mark.parametrize("exception", mqtt_publish_exceptions) + async def test_mqtt_publish_exception(self, client, exception): + client._mqtt_client.publish.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await client.send_polling(FAKE_OPERATION_ID) + assert e_info.value is exception + + @pytest.mark.it("Deletes the Request from the RequestLedger if MQTTClient publish raises") + @pytest.mark.parametrize("exception", mqtt_publish_exceptions) + async def test_mqtt_publish_exception_cleanup(self, mocker, client, exception): + client._mqtt_client.publish.side_effect = exception + spy_create_request = mocker.spy(client._request_ledger, "create_request") + spy_delete_request = mocker.spy(client._request_ledger, "delete_request") + # Mock the uuid call as well to return fake request id + mocker.patch.object(uuid, "uuid4", return_value=FAKE_POLLING_REQUEST_ID) + + with pytest.raises(type(exception)): + await client.send_polling(FAKE_OPERATION_ID) + + assert spy_create_request.await_count == 1 + assert spy_create_request.await_args == mocker.call(FAKE_POLLING_REQUEST_ID) + assert spy_delete_request.await_count == 1 + assert spy_delete_request.await_args == mocker.call( + spy_create_request.spy_return.request_id + ) + + @pytest.mark.it( + "Deletes the Request from the RequestLedger if cancelled while waiting for the MQTTClient publish to finish" + ) + async def test_mqtt_publish_cancel_cleanup(self, mocker, client): + client._mqtt_client.publish = custom_mock.HangingAsyncMock() + spy_create_request = mocker.spy(client._request_ledger, 
"create_request") + spy_delete_request = mocker.spy(client._request_ledger, "delete_request") + # Mock the uuid call as well to return fake request id + mocker.patch.object(uuid, "uuid4", return_value=FAKE_POLLING_REQUEST_ID) + + t = asyncio.create_task(client.send_polling(FAKE_OPERATION_ID)) + + # Hanging, waiting for MQTT publish to finish + await client._mqtt_client.publish.wait_for_hang() + assert not t.done() + + # Request was created, but not yet deleted + assert spy_create_request.await_count == 1 + assert spy_create_request.await_args == mocker.call(FAKE_POLLING_REQUEST_ID) + assert spy_delete_request.await_count == 0 + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + # The Request that was created has now been deleted + assert spy_delete_request.await_count == 1 + assert spy_delete_request.await_args == mocker.call( + spy_create_request.spy_return.request_id + ) + + @pytest.mark.it( + "Deletes the Request from the RequestLedger if cancelled while waiting for a register dps response" + ) + async def test_waiting_response_cancel_cleanup(self, mocker, client): + # Override autocompletion behavior on publish (we don't want it here) + client._mqtt_client.publish = mocker.AsyncMock() + + # Mock Request creation to return a specific, mocked request that hangs on + # awaiting a Response + request = rr.Request(FAKE_POLLING_REQUEST_ID) + request.get_response = custom_mock.HangingAsyncMock() + mocker.patch.object(rr, "Request", return_value=request) + spy_create_request = mocker.spy(client._request_ledger, "create_request") + spy_delete_request = mocker.spy(client._request_ledger, "delete_request") + + # Mock the uuid call as well to return fake request id + mocker.patch.object(uuid, "uuid4", return_value=FAKE_POLLING_REQUEST_ID) + + send_task = asyncio.create_task(client.send_polling(FAKE_OPERATION_ID)) + + # Hanging, waiting for response + await request.get_response.wait_for_hang() + assert not send_task.done() + + # Request was 
created, but not yet deleted + assert spy_create_request.await_count == 1 + assert spy_create_request.await_args == mocker.call(FAKE_POLLING_REQUEST_ID) + assert spy_delete_request.await_count == 0 + + # Cancel + send_task.cancel() + with pytest.raises(asyncio.CancelledError): + await send_task + + # The Request that was created was also deleted + assert spy_delete_request.await_count == 1 + assert spy_delete_request.await_args == mocker.call(request.request_id) + + @pytest.mark.it("Waits to retrieve response within a default period of time") + async def test_waiting_response_within_default_timeout(self, mocker, client): + # Override autocompletion behavior on publish (we don't need it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = FAKE_POLLING_REQUEST_ID + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + + # Mock out the response to return completely successful response + mock_response = mocker.MagicMock(spec=rr.Response) + mock_response.status = 200 + response_body_dict = { + "operationId": FAKE_OPERATION_ID, + "status": FAKE_STATUS, + "registrationState": { + "assignedHub": FAKE_ASSIGNED_HUB, + "createdDateTimeUtc": None, + "deviceId": FAKE_DEVICE_ID, + "etag": None, + "lastUpdatedDateTimeUtc": None, + "payload": None, + "subStatus": FAKE_SUB_STATUS, + }, + } + fake_response_string = json.dumps(response_body_dict) + mock_response.body = fake_response_string + + mock_request.get_response.return_value = mock_response + + spy_wait_for = mocker.spy(asyncio, "wait_for") + + await client.send_polling(operation_id=FAKE_OPERATION_ID) + + assert spy_wait_for.await_count == 1 + assert asyncio.iscoroutine(spy_wait_for.await_args.args[0]) + assert spy_wait_for.await_args.args[1] == DEFAULT_TIMEOUT_INTERVAL + + @pytest.mark.it( + "Raises ProvisioningServiceError from TimeoutError if the response 
is not retrieved within a default period" + ) + async def test_waiting_response_timeout_exception(self, mocker, client): + # Override autocompletion behavior on publish (we don't need it here) + client._mqtt_client.publish = mocker.AsyncMock() + # Mock out the ledger to return a mocked request + mock_request = mocker.MagicMock(spec=rr.Request) + mock_request.request_id = FAKE_POLLING_REQUEST_ID + mocker.patch.object(client._request_ledger, "create_request", return_value=mock_request) + + # Simulate a timeout + fake_timeout = asyncio.TimeoutError("fake timeout exception") + + def simulate_timeout(awaitable, timeout): + # Create a task from the awaitable, and then cancel it so that no warnings are thrown + # about the awaitable never being awaited. + # This is because in the real implementation, the awaitable is scheduled on the + # event loop, and then later cancelled in the case of timeout + t = asyncio.create_task(awaitable) + t.cancel() + raise fake_timeout + + mocker.patch.object(asyncio, "wait_for", side_effect=simulate_timeout) + + # Timeout causes ProvisioningServiceError + with pytest.raises(exc.ProvisioningServiceError) as e_info: + await client.send_polling(operation_id=FAKE_OPERATION_ID) + assert e_info.value.__cause__ is fake_timeout + + @pytest.mark.it( + "Deletes the Request from the RequestLedger if timeout occurred while waiting for a register dps response" + ) + async def test_waiting_response_timeout_cleanup(self, mocker, client): + # Override autocompletion behavior on publish (we don't need it here) + client._mqtt_client.publish = mocker.AsyncMock() + + # Mock the uuid call as well to return fake request id + mocker.patch.object(uuid, "uuid4", return_value=FAKE_POLLING_REQUEST_ID) + + spy_create_request = mocker.spy(client._request_ledger, "create_request") + spy_delete_request = mocker.spy(client._request_ledger, "delete_request") + + # Simulate a timeout + fake_timeout = asyncio.TimeoutError("fake timeout exception") + + def 
simulate_timeout(awaitable, timeout): + # Create a task from the awaitable, and then cancel it so that no warnings are thrown + # about the awaitable never being awaited. + # This is because in the real implementation, the awaitable is scheduled on the + # event loop, and then later cancelled in the case of timeout + t = asyncio.create_task(awaitable) + t.cancel() + raise fake_timeout + + mocker.patch.object(asyncio, "wait_for", side_effect=simulate_timeout) + + # Timeout occurs + with pytest.raises(exc.ProvisioningServiceError) as e_info: + await client.send_polling(operation_id=FAKE_OPERATION_ID) + assert e_info.value.__cause__ is fake_timeout + + # Request was created, but not yet deleted + assert spy_create_request.await_count == 1 + + # The Request that was created was also deleted + assert spy_delete_request.await_count == 1 + assert spy_delete_request.await_args == mocker.call(FAKE_POLLING_REQUEST_ID) + + +@pytest.mark.describe("ProvisioningMQTTClient - PROPERTY: .connected") +class TestProvisioningMQTTClientConnected: + @pytest.mark.it("Returns the result of the MQTTClient's .is_connected() method") + def test_returns_result(self, mocker, client): + assert client._mqtt_client.is_connected.call_count == 0 + + result = client.connected + + assert client._mqtt_client.is_connected.call_count == 1 + assert client._mqtt_client.is_connected.call_args == mocker.call() + assert result is client._mqtt_client.is_connected.return_value + + +@pytest.mark.describe("ProvisioningMQTTClient - BG TASK: ._process_dps_responses") +class TestProvisioningMQTTClientProcessDPSResponses: + response_payloads = [ + pytest.param('{"json": "in", "a": {"string": "format"}}', id="Some DPS Response"), + pytest.param(" ", id="DPS Empty Response"), + ] + + @pytest.mark.it( + "Creates a Response containing the request id and status code and extracted properties from the topic, " + "as well as the utf-8 decoded payload of the MQTTMessage, when the MQTTClient receives an " + "MQTTMessage on 
the dps response topic" + ) + @pytest.mark.parametrize( + "status", + [ + pytest.param(200, id="Status Code: 200"), + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status Code: 400"), + pytest.param(500, id="Status Code: 500"), + ], + ) + @pytest.mark.parametrize("payload_str", response_payloads) + async def test_response(self, mocker, client, status, payload_str): + # Mocks + mocker.patch.object(client, "_request_ledger", spec=rr.RequestLedger) + spy_response_factory = mocker.spy(rr, "Response") + # Set up MQTTMessages + generic_topic = mqtt_topic.get_response_topic_for_subscribe() + rid = "some rid" + value1 = "fake value 1" + value2 = "fake value 2" + props = {"$rid": rid, "prop1": value1, "prop2": value2} + msg_topic = "$dps/registrations/res/{status}/?$rid={rid}&prop1={v1}&prop2={v2}".format( + status=status, rid=rid, v1=value1, v2=value2 + ) + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=msg_topic.encode("utf-8")) + mqtt_msg.payload = payload_str.encode("utf-8") + + # Start task + t = asyncio.create_task(client._process_dps_responses()) + await asyncio.sleep(0.1) + + # No Responses have been created yet + assert spy_response_factory.call_count == 0 + + # Load the MQTTMessage into the MQTTClient's filtered message queue + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg) + await asyncio.sleep(0.1) + + # Response was created + assert spy_response_factory.call_count == 1 + resp1 = spy_response_factory.spy_return + assert resp1.request_id == rid + assert resp1.status == status + assert resp1.body == payload_str + assert resp1.properties == props + + t.cancel() + + @pytest.mark.it("Matches the newly created Response on the RequestLedger") + @pytest.mark.parametrize("payload_str", response_payloads) + async def test_match(self, mocker, client, payload_str): + # Mock + mock_ledger = mocker.patch.object(client, "_request_ledger", spec=rr.RequestLedger) + spy_response_factory = mocker.spy(rr, "Response") + # Set up 
MQTTMessage + generic_topic = mqtt_topic.get_response_topic_for_subscribe() + topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, "some rid") + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=topic.encode("utf-8")) + mqtt_msg.payload = payload_str.encode("utf-8") + + # No Responses have been created yet + assert spy_response_factory.call_count == 0 + assert mock_ledger.match_response.call_count == 0 + + # Start task + t = asyncio.create_task(client._process_dps_responses()) + await asyncio.sleep(0.1) + + # Load the MQTTMessage into the MQTTClient's filtered message queue + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg) + await asyncio.sleep(0.1) + + # Response was created + assert spy_response_factory.call_count == 1 + resp1 = spy_response_factory.spy_return + assert mock_ledger.match_response.call_count == 1 + assert mock_ledger.match_response.call_args == mocker.call(resp1) + + t.cancel() + + @pytest.mark.it("Indefinitely repeats") + async def test_repeat(self, mocker, client): + mock_ledger = mocker.patch.object(client, "_request_ledger", spec=rr.RequestLedger) + spy_response_factory = mocker.spy(rr, "Response") + assert spy_response_factory.call_count == 0 + assert mock_ledger.match_response.call_count == 0 + + generic_topic = mqtt_topic.get_response_topic_for_subscribe() + topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, "some rid") + + t = asyncio.create_task(client._process_dps_responses()) + await asyncio.sleep(0.1) + + # Test that behavior repeats up to 10 times. 
No way to really prove infinite + i = 0 + assert mock_ledger.match_response.call_count == 0 + while i < 10: + i += 1 + mqtt_msg = mqtt.MQTTMessage(mid=1, topic=topic.encode("utf-8")) + # Switch between Polling and Register responses + if i % 2 == 0: + mqtt_msg.payload = '{"operationId":"fake_operation_id","status":"assigning","registrationState":{"registrationId":"fake_reg_id","status":"assigning"}}'.encode( + "utf-8" + ) + else: + mqtt_msg.payload = ( + '{"operationId":"fake_operation_id","status":"assigning"}'.encode("utf-8") + ) + # Load the MQTTMessage into the MQTTClient's filtered message queue + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg) + await asyncio.sleep(0.1) + # Response was created + assert spy_response_factory.call_count == i + + assert not t.done() + t.cancel() + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while extracting properties along with request id from the " + "MQTTMessage, dropping the MQTTMessage and continuing" + ) + async def test_request_properties_extraction_fails(self, mocker, client, arbitrary_exception): + # Inject failure + original_fn = mqtt_topic.extract_properties_from_response_topic + mocker.patch.object( + mqtt_topic, + "extract_properties_from_response_topic", + side_effect=arbitrary_exception, + ) + + # Create two messages that are the same other than the request id + generic_topic = mqtt_topic.get_response_topic_for_subscribe() + # Response #1 + rid1 = "rid1" + msg1_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid1) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=msg1_topic.encode("utf-8")) + mqtt_msg1.payload = " ".encode("utf-8") + # Response #2 + rid2 = "rid2" + msg2_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid2) + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=msg2_topic.encode("utf-8")) + mqtt_msg2.payload = " ".encode("utf-8") + + # Spy on the Response object + spy_response_factory = mocker.spy(rr, "Response") + + # Start task + 
t = asyncio.create_task(client._process_dps_responses()) + + # Load the first MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await asyncio.sleep(0.1) + + # No Response was created due to the injected failure (but failure was suppressed) + assert spy_response_factory.call_count == 0 + mqtt_topic.extract_properties_from_response_topic.call_count == 1 + + # Un-inject the failure + mqtt_topic.extract_properties_from_response_topic = original_fn + + # Load the second MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + await asyncio.sleep(0.1) + + # This time a Response was created, demonstrating that the task is still functional + assert spy_response_factory.call_count == 1 + resp2 = spy_response_factory.spy_return + assert resp2.request_id == rid2 + assert resp2.status == 200 + assert resp2.body == mqtt_msg2.payload.decode("utf-8") + + assert not t.done() + + t.cancel() + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while extracting the status code from the MQTTMessage, " + "dropping the MQTTMessage and continuing" + ) + async def test_status_code_extraction_fails(self, mocker, client, arbitrary_exception): + # Inject failure + original_fn = mqtt_topic.extract_status_code_from_response_topic + mocker.patch.object( + mqtt_topic, + "extract_status_code_from_response_topic", + side_effect=arbitrary_exception, + ) + + # Create two messages that are the same other than the request id + generic_topic = mqtt_topic.get_response_topic_for_subscribe() + # Response #1 + rid1 = "rid1" + msg1_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid1) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=msg1_topic.encode("utf-8")) + mqtt_msg1.payload = " ".encode("utf-8") + # Response #2 + rid2 = "rid2" + msg2_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid2) + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=msg2_topic.encode("utf-8")) + 
mqtt_msg2.payload = " ".encode("utf-8") + + # Spy on the Response object + spy_response_factory = mocker.spy(rr, "Response") + + # Start task + t = asyncio.create_task(client._process_dps_responses()) + + # Load the first MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await asyncio.sleep(0.1) + + # No Response was created due to the injected failure (but failure was suppressed) + assert spy_response_factory.call_count == 0 + assert mqtt_topic.extract_status_code_from_response_topic.call_count == 1 + + # Un-inject the failure + mqtt_topic.extract_status_code_from_response_topic = original_fn + + # Load the second MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + await asyncio.sleep(0.1) + + # This time a Response was created, demonstrating that the previous failure did not + # crash the task + assert spy_response_factory.call_count == 1 + resp2 = spy_response_factory.spy_return + assert resp2.request_id == rid2 + assert resp2.status == 200 + assert resp2.body == mqtt_msg2.payload.decode("utf-8") + + assert not t.done() + + t.cancel() + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised while decoding the payload from the MQTTMessage, dropping the MQTTMessage and continuing" + ) + async def test_payload_decode_fails(self, mocker, client, arbitrary_exception): + # Create two messages that are the same other than the request id + generic_topic = mqtt_topic.get_response_topic_for_subscribe() + # Response #1 + rid1 = "rid1" + msg1_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid1) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=msg1_topic.encode("utf-8")) + mqtt_msg1.payload = " ".encode("utf-8") + # Response #2 + rid2 = "rid2" + msg2_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid2) + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=msg2_topic.encode("utf-8")) + mqtt_msg2.payload = " ".encode("utf-8") + + # Spy on the Response object + 
spy_response_factory = mocker.spy(rr, "Response") + + # Inject failure into the first MQTTMessage's payload + mqtt_msg1.payload = mocker.MagicMock() + mqtt_msg1.payload.decode.side_effect = arbitrary_exception + + # Start task + t = asyncio.create_task(client._process_dps_responses()) + + # Load the first MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await asyncio.sleep(0.1) + + # No Response was created due to the injected failure (but failure was suppressed) + assert spy_response_factory.call_count == 0 + assert mqtt_msg1.payload.decode.call_count == 1 + + # Load the second MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + await asyncio.sleep(0.1) + + # This time a Response was created, demonstrating that the previous failure did not + # crash the task + assert spy_response_factory.call_count == 1 + resp2 = spy_response_factory.spy_return + assert resp2.request_id == rid2 + assert resp2.status == 200 + assert resp2.body == mqtt_msg2.payload.decode("utf-8") + + assert not t.done() + + t.cancel() + + @pytest.mark.it( + "Suppresses any unexpected exceptions raised instantiating the Response object from the MQTTMessage values, dropping the MQTTMessage and continuing" + ) + async def test_response_instantiation_fails(self, mocker, client, arbitrary_exception): + # Inject failure + original_cls = rr.Response + mocker.patch.object(rr, "Response", side_effect=arbitrary_exception) + + # Create two messages that are the same other than the request id + generic_topic = mqtt_topic.get_response_topic_for_subscribe() + # Response #1 + rid1 = "rid1" + msg1_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid1) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=msg1_topic.encode("utf-8")) + mqtt_msg1.payload = " ".encode("utf-8") + # Response #2 + rid2 = "rid2" + msg2_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid2) + mqtt_msg2 = 
mqtt.MQTTMessage(mid=1, topic=msg2_topic.encode("utf-8")) + mqtt_msg2.payload = " ".encode("utf-8") + + # Mock the ledger so we can see if it is used + mock_ledger = mocker.patch.object(client, "_request_ledger", spec=rr.RequestLedger) + + # Start task + t = asyncio.create_task(client._process_dps_responses()) + + # Load the first MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await asyncio.sleep(0.1) + + # No Response matched to the injected failure (but failure was suppressed) + assert mock_ledger.match_response.call_count == 0 + assert rr.Response.call_count == 1 + + # Un-inject the failure + rr.Response = original_cls + + # Load the second MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + await asyncio.sleep(0.1) + + # This time a Response was created and matched, demonstrating that the previous + # failure did not crash the task + assert mock_ledger.match_response.call_count == 1 + resp = mock_ledger.match_response.call_args[0][0] + assert resp.request_id == rid2 + assert resp.status == 200 + assert resp.body == mqtt_msg2.payload.decode("utf-8") + + assert not t.done() + + t.cancel() + + @pytest.mark.it( + "Suppresses any exceptions raised while matching the Response, dropping the MQTTMessage and continuing" + ) + @pytest.mark.parametrize( + "exception", + [ + pytest.param(KeyError(), id="KeyError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_response_match_fails(self, mocker, client, exception): + mock_ledger = mocker.patch.object(client, "_request_ledger", spec=rr.RequestLedger) + + # Create two messages that are the same other than the request id + generic_topic = mqtt_topic.get_response_topic_for_subscribe() + # Response #1 + rid1 = "rid1" + msg1_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid1) + mqtt_msg1 = mqtt.MQTTMessage(mid=1, topic=msg1_topic.encode("utf-8")) + 
mqtt_msg1.payload = " ".encode("utf-8") + # Response #2 + rid2 = "rid2" + msg2_topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, rid2) + mqtt_msg2 = mqtt.MQTTMessage(mid=1, topic=msg2_topic.encode("utf-8")) + mqtt_msg2.payload = " ".encode("utf-8") + + # Inject failure into the response match + mock_ledger.match_response.side_effect = exception + + # Start task + t = asyncio.create_task(client._process_dps_responses()) + + # Load the first MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg1) + await asyncio.sleep(0.1) + + # Attempt to match response ocurred (and thus, failed, due to mock) + assert mock_ledger.match_response.call_count == 1 + + # Un-inject the failure + mock_ledger.match_response.side_effect = None + + # Load the second MQTTMessage + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg2) + await asyncio.sleep(0.1) + + # Another response match ocurred, demonstrating that the previous failure did not + # crash the task + assert mock_ledger.match_response.call_count == 2 + resp2 = mock_ledger.match_response.call_args[0][0] + assert resp2.request_id == rid2 + assert resp2.status == 200 + assert resp2.body == mqtt_msg2.payload.decode("utf-8") + + assert not t.done() + + t.cancel() + + @pytest.mark.skip(reason="Currently can't figure out how to mock a generator correctly") + @pytest.mark.it("Can be cancelled while waiting for an MQTTMessage to arrive") + async def test_cancelled_while_waiting_for_message(self): + pass + + @pytest.mark.it("Can be cancelled while matching Response") + async def test_cancelled_while_matching_response(self, mocker, client): + mock_ledger = mocker.patch.object(client, "_request_ledger", spec=rr.RequestLedger) + mock_ledger.match_response = custom_mock.HangingAsyncMock() + + # Set up MQTTMessage + generic_topic = mqtt_topic.get_response_topic_for_subscribe() + topic = generic_topic.rstrip("#") + "{}/?$rid={}".format(200, "some rid") + mqtt_msg 
= mqtt.MQTTMessage(mid=1, topic=topic.encode("utf-8")) + mqtt_msg.payload = " ".encode("utf-8") + + # Start task + t = asyncio.create_task(client._process_dps_responses()) + await asyncio.sleep(0.1) + + # Load the MQTTMessage into the MQTTClient's filtered message queue + await client._mqtt_client._incoming_filtered_messages[generic_topic].put(mqtt_msg) + + # Matching response is hanging + await mock_ledger.match_response.wait_for_hang() + + # Task can be cancelled + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe("ProvisioningMQTTClient - .wait_for_disconnect()") +class TestProvisioningMQTTClientReportConnectionDrop: + @pytest.mark.it( + "Returns None if an expected disconnect has previously occurred in the MQTTClient" + ) + async def test_previous_expected_disconnect(self, client): + # Simulate expected disconnect + client._mqtt_client._disconnection_cause = None + client._mqtt_client.is_connected.return_value = False + async with client._mqtt_client.disconnected_cond: + client._mqtt_client.disconnected_cond.notify_all() + + # Reports no cause (i.e. expected disconnect) + t = asyncio.create_task(client.wait_for_disconnect()) + await asyncio.sleep(0.1) + assert t.done() + assert t.result() is None + + @pytest.mark.it( + "Waits for a disconnect to occur in the MQTTClient, and returns None once an expected disconnect occurs, if no disconnect has yet ocurred" + ) + async def test_expected_disconnect(self, client): + # No connection drop to report + t = asyncio.create_task(client.wait_for_disconnect()) + await asyncio.sleep(0.1) + assert not t.done() + + # Simulate expected disconnect + client._mqtt_client._disconnection_cause = None + client._mqtt_client.is_connected.return_value = False + async with client._mqtt_client.disconnected_cond: + client._mqtt_client.disconnected_cond.notify_all() + + # Report no cause (i.e. 
expected disconnect) + await asyncio.sleep(0.1) + assert t.done() + assert t.result() is None + + @pytest.mark.it( + "Returns the MQTTError that caused an unexpected disconnect in the MQTTClient, if an unexpected disconnect has already occurred" + ) + async def test_previous_unexpected_disconnect(self, client): + # Simulate unexpected disconnect + cause = mqtt.MQTTError(rc=7) + client._mqtt_client._disconnection_cause = cause + client._mqtt_client.is_connected.return_value = False + async with client._mqtt_client.disconnected_cond: + client._mqtt_client.disconnected_cond.notify_all() + + # Reports the cause that is already available + t = asyncio.create_task(client.wait_for_disconnect()) + await asyncio.sleep(0.1) + assert t.done() + assert t.result() is cause + + @pytest.mark.it( + "Waits for a disconnect to occur in the MQTTClient, and returns the MQTTError that caused it once an unexpected disconnect occurs, if a disconnect has not yet occurred" + ) + async def test_unexpected_disconnect(self, client): + # No connection drop to report yet + t = asyncio.create_task(client.wait_for_disconnect()) + await asyncio.sleep(0.1) + assert not t.done() + + # Simulate unexpected disconnect + cause = mqtt.MQTTError(rc=7) + client._mqtt_client._disconnection_cause = cause + client._mqtt_client.is_connected.return_value = False + async with client._mqtt_client.disconnected_cond: + client._mqtt_client.disconnected_cond.notify_all() + + # Cause can now be reported + await asyncio.sleep(0.1) + assert t.done() + assert t.result() is cause diff --git a/tests/unit/test_provisioning_session.py b/tests/unit/test_provisioning_session.py new file mode 100644 index 000000000..e96662062 --- /dev/null +++ b/tests/unit/test_provisioning_session.py @@ -0,0 +1,1233 @@ +import asyncio +import pytest +import ssl +import time +from dev_utils import custom_mock +from pytest_lazyfixture import lazy_fixture +from azure.iot.device.provisioning_session import ProvisioningSession +from
azure.iot.device import config, constant +from azure.iot.device import exceptions as exc +from azure.iot.device import provisioning_mqtt_client as mqtt +from azure.iot.device import sastoken as st +from azure.iot.device import signing_mechanism as sm + +FAKE_REGISTRATION_ID = "fake_registration_id" +FAKE_ID_SCOPE = "fake_idscope" +FAKE_HOSTNAME = "fake.hostname" +FAKE_URI = "fake/resource/location" +FAKE_SHARED_ACCESS_KEY = "Zm9vYmFy" +FAKE_SIGNATURE = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=" + +# ~~~~~ Helpers ~~~~~~ + + +def sastoken_generator_fn(): + return "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( + resource=FAKE_URI, signature=FAKE_SIGNATURE, expiry=str(int(time.time()) + 3600) + ) + + +def get_expected_uri(): + return "{id_scope}/registrations/{registration_id}".format( + id_scope=FAKE_ID_SCOPE, registration_id=FAKE_REGISTRATION_ID + ) + + +# ~~~~~ Fixtures ~~~~~~ + +# Mock out the underlying client in order to not do network operations +@pytest.fixture(autouse=True) +def mock_mqtt_provisioning_client(mocker): + mock_client = mocker.patch.object( + mqtt, "ProvisioningMQTTClient", spec=mqtt.ProvisioningMQTTClient + ).return_value + # Use a HangingAsyncMock here so that the coroutine does not return until we want it to + mock_client.wait_for_disconnect = custom_mock.HangingAsyncMock() + return mock_client + + +@pytest.fixture(autouse=True) +def mock_sastoken_provider(mocker): + return mocker.patch.object(st, "SasTokenProvider", spec=st.SasTokenProvider).return_value + + +@pytest.fixture +def custom_ssl_context(): + # NOTE: It doesn't matter how the SSLContext is configured for the tests that use this fixture, + # so it isn't configured at all. 
+ return ssl.SSLContext() + + +@pytest.fixture(params=["Default SSLContext", "Custom SSLContext"]) +def optional_ssl_context(request, custom_ssl_context): + """Sometimes tests need to show something works with or without an SSLContext""" + if request.param == "Custom SSLContext": + return custom_ssl_context + else: + return None + + +@pytest.fixture +async def session(custom_ssl_context): + """Use a symmetric key configuration and custom SSL auth for simplicity""" + async with ProvisioningSession( + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + shared_access_key=FAKE_SHARED_ACCESS_KEY, + ssl_context=custom_ssl_context, + ) as session: + yield session + + +@pytest.fixture +def disconnected_session(custom_ssl_context): + return ProvisioningSession( + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + ssl_context=custom_ssl_context, + ) + + +# ~~~~~ Parametrizations ~~~~~ +# Define parametrizations that will be used across multiple test suites, and that may eventually +# need to be changed everywhere, e.g. new auth scheme added. +# Note that some parametrizations are also defined within the scope of a single test suite if that +# is the only unit they are relevant to. + + +# Parameters for arguments to the __init__ or factory methods. Represent different types of +# authentication. Use this parametrization whenever possible on .create() tests. +# NOTE: Do NOT combine this with the SSL fixtures above. 
This parametrization contains +# ssl contexts where necessary +create_auth_params = [ + # Provide args in form 'shared_access_key, sastoken_fn, ssl_context' + pytest.param( + FAKE_SHARED_ACCESS_KEY, None, None, id="Shared Access Key SAS Auth + Default SSLContext" + ), + pytest.param( + FAKE_SHARED_ACCESS_KEY, + None, + lazy_fixture("custom_ssl_context"), + id="Shared Access Key SAS Auth + Custom SSLContext", + ), + pytest.param( + None, + sastoken_generator_fn, + None, + id="User-Provided SAS Token Auth + Default SSLContext", + ), + pytest.param( + None, + sastoken_generator_fn, + lazy_fixture("custom_ssl_context"), + id="User-Provided SAS Token Auth + Custom SSLContext", + ), + pytest.param(None, None, lazy_fixture("custom_ssl_context"), id="Custom SSLContext Auth"), +] +# # Just the parameters where SAS auth is used +create_auth_params_sas = [param for param in create_auth_params if "SAS" in param.id] +# Just the parameters where a Shared Access Key auth is used +create_auth_params_sak = [param for param in create_auth_params if param.values[0] is not None] +# Just the parameters where SAS callback auth is used +create_auth_params_token_cb = [param for param in create_auth_params if param.values[1] is not None] +# Just the parameters where a custom SSLContext is provided +create_auth_params_custom_ssl = [ + param for param in create_auth_params if param.values[2] is not None +] +# Just the parameters where a custom SSLContext is NOT provided +create_auth_params_default_ssl = [param for param in create_auth_params if param.values[2] is None] + + +# Covers all option kwargs shared across client factory methods +factory_kwargs = [ + # pytest.param("auto_reconnect", False, id="auto_reconnect"), + pytest.param("keep_alive", 34, id="keep_alive"), + pytest.param( + "proxy_options", config.ProxyOptions("HTTP", "fake.address", 1080), id="proxy_options" + ), + pytest.param("websockets", True, id="websockets"), +] + +sk_sm_create_exceptions = [ + pytest.param(ValueError(), 
id="ValueError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), +] + +json_serializable_payload_params = [ + pytest.param("String payload", id="String Payload"), + pytest.param(123, id="Integer Payload"), + pytest.param(2.0, id="Float Payload"), + pytest.param(True, id="Boolean Payload"), + pytest.param({"dictionary": {"payload": "nested"}}, id="Dictionary Payload"), + pytest.param([1, 2, 3], id="List Payload"), + pytest.param((1, 2, 3), id="Tuple Payload"), + pytest.param(None, id="No Payload"), +] + + +@pytest.mark.describe("ProvisioningSession -- Instantiation") +class TestProvisioningSessionInstantiation: + @pytest.mark.it( + "Instantiates and stores a SasTokenProvider that uses symmetric key-based token generation, if `shared_access_key` is provided" + ) + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params_sak) + async def test_sak_auth(self, mocker, shared_access_key, sastoken_fn, ssl_context): + assert sastoken_fn is None + spy_sk_sm_cls = mocker.spy(sm, "SymmetricKeySigningMechanism") + spy_st_generator_cls = mocker.spy(st, "InternalSasTokenGenerator") + spy_st_provider_cls = mocker.spy(st, "SasTokenProvider") + expected_uri = get_expected_uri() + + session = ProvisioningSession( + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + shared_access_key=shared_access_key, + ssl_context=ssl_context, + ) + + # SymmetricKeySigningMechanism was created from the shared access key + assert spy_sk_sm_cls.call_count == 1 + assert spy_sk_sm_cls.call_args == mocker.call(shared_access_key) + # InternalSasTokenGenerator was created from the SymmetricKeySigningMechanism + assert spy_st_generator_cls.call_count == 1 + assert spy_st_generator_cls.call_args == mocker.call( + signing_mechanism=spy_sk_sm_cls.spy_return, uri=expected_uri, ttl=3600 + ) + # SasTokenProvider was created from the InternalSasTokenGenerator + assert spy_st_provider_cls.call_count == 1 + assert 
spy_st_provider_cls.call_args == mocker.call(spy_st_generator_cls.spy_return) + # SasTokenProvider was set on the Session + assert session._sastoken_provider is spy_st_provider_cls.spy_return + + @pytest.mark.it( + "Instantiates and stores a SasTokenProvider that uses callback-based token generation, if `sastoken_fn` is provided" + ) + @pytest.mark.parametrize( + "shared_access_key, sastoken_fn, ssl_context", create_auth_params_token_cb + ) + async def test_token_callback_auth(self, mocker, shared_access_key, sastoken_fn, ssl_context): + assert shared_access_key is None + spy_st_generator_cls = mocker.spy(st, "ExternalSasTokenGenerator") + spy_st_provider_cls = mocker.spy(st, "SasTokenProvider") + + session = ProvisioningSession( + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + shared_access_key=shared_access_key, + ssl_context=ssl_context, + sastoken_fn=sastoken_fn, + ) + + # ExternalSasTokenGenerator was created from the sastoken_fn + assert spy_st_generator_cls.call_count == 1 + assert spy_st_generator_cls.call_args == mocker.call(sastoken_fn) + # SasTokenProvider was created from the ExternalSasTokenGenerator + assert spy_st_provider_cls.call_count == 1 + assert spy_st_provider_cls.call_args == mocker.call(spy_st_generator_cls.spy_return) + # SasTokenProvider was set on the Session + assert session._sastoken_provider is spy_st_provider_cls.spy_return + + @pytest.mark.it( + "Does not instantiate or store any SasTokenProvider if neither `shared_access_key` nor `sastoken_fn` are provided" + ) + async def test_non_sas_auth(self, mocker, custom_ssl_context): + spy_st_provider_cls = mocker.spy(st, "SasTokenProvider") + + session = ProvisioningSession( + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + ssl_context=custom_ssl_context, + ) + + # No SasTokenProvider + assert session._sastoken_provider is None + assert spy_st_provider_cls.call_count == 0 + + @pytest.mark.it( + "Instantiates and stores an ProvisioningMQTTClient, using a 
new ProvisioningClientConfig object" + ) + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params) + async def test_mqtt_client(self, mocker, shared_access_key, sastoken_fn, ssl_context): + spy_config_cls = mocker.spy(config, "ProvisioningClientConfig") + spy_mqtt_cls = mocker.spy(mqtt, "ProvisioningMQTTClient") + assert spy_config_cls.call_count == 0 + assert spy_mqtt_cls.call_count == 0 + + session = ProvisioningSession( + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + shared_access_key=shared_access_key, + ssl_context=ssl_context, + sastoken_fn=sastoken_fn, + ) + + assert spy_config_cls.call_count == 1 + assert spy_mqtt_cls.call_count == 1 + assert spy_mqtt_cls.call_args == mocker.call(spy_config_cls.spy_return) + assert session._mqtt_client is spy_mqtt_cls.spy_return + + @pytest.mark.it( + "Sets the provided `provisioning_endpoint` as the `hostname` on the ProvisioningClientConfig used to create the ProvisioningMQTTClient, if `provisioning_endpoint` is provided" + ) + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params) + async def test_custom_endpoint(self, mocker, shared_access_key, sastoken_fn, ssl_context): + spy_mqtt_cls = mocker.spy(mqtt, "ProvisioningMQTTClient") + + ProvisioningSession( + provisioning_endpoint=FAKE_HOSTNAME, + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + shared_access_key=shared_access_key, + ssl_context=ssl_context, + sastoken_fn=sastoken_fn, + ) + + cfg = spy_mqtt_cls.call_args[0][0] + assert cfg.hostname == FAKE_HOSTNAME + + @pytest.mark.it( + "Sets the Global Provisioning Endpoint as the `hostname` on the ProvisioningClientConfig used to create the ProvisioningMQTTClient, if no `provisioning_endpoint` is provided" + ) + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params) + async def test_default_endpoint(self, mocker, shared_access_key, sastoken_fn, ssl_context): + spy_mqtt_cls = 
mocker.spy(mqtt, "ProvisioningMQTTClient") + + ProvisioningSession( + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + shared_access_key=shared_access_key, + ssl_context=ssl_context, + sastoken_fn=sastoken_fn, + ) + + cfg = spy_mqtt_cls.call_args[0][0] + assert cfg.hostname == constant.PROVISIONING_GLOBAL_ENDPOINT + + @pytest.mark.it( + "Sets the provided `registration_id` and `id_scope` values on the ProvisioningClientConfig used to create the ProvisioningMQTTClient" + ) + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params) + async def test_ids(self, mocker, shared_access_key, sastoken_fn, ssl_context): + spy_mqtt_cls = mocker.spy(mqtt, "ProvisioningMQTTClient") + + ProvisioningSession( + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + shared_access_key=shared_access_key, + ssl_context=ssl_context, + sastoken_fn=sastoken_fn, + ) + + cfg = spy_mqtt_cls.call_args[0][0] + assert cfg.registration_id == FAKE_REGISTRATION_ID + assert cfg.id_scope == FAKE_ID_SCOPE + + @pytest.mark.it( + "Sets the provided `ssl_context` on the ProvisioningClientConfig used to create the ProvisioningMQTTClient, if provided" + ) + @pytest.mark.parametrize( + "shared_access_key, sastoken_fn, ssl_context", create_auth_params_custom_ssl + ) + async def test_custom_ssl_context(self, mocker, shared_access_key, sastoken_fn, ssl_context): + spy_mqtt_cls = mocker.spy(mqtt, "ProvisioningMQTTClient") + + ProvisioningSession( + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + shared_access_key=shared_access_key, + ssl_context=ssl_context, + sastoken_fn=sastoken_fn, + ) + + cfg = spy_mqtt_cls.call_args[0][0] + assert cfg.ssl_context is ssl_context + + @pytest.mark.it( + "Sets a default SSLContext on the ProvisioningClientConfig used to create the ProvisioningMQTTClient, if `ssl_context` is not provided" + ) + @pytest.mark.parametrize( + "shared_access_key, sastoken_fn, ssl_context", 
create_auth_params_default_ssl + ) + async def test_default_ssl_context(self, mocker, shared_access_key, sastoken_fn, ssl_context): + assert ssl_context is None + spy_mqtt_cls = mocker.spy(mqtt, "ProvisioningMQTTClient") + my_ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) + original_ssl_ctx_cls = ssl.SSLContext + + # NOTE: SSLContext is difficult to mock as an entire class, due to how it implements + # instantiation. Essentially, if you mock the entire class, it will not be able to + # instantiate due to an internal reference to the class type, which of course has now been + # changed to MagicMock. To get around this, we mock the class with a side effect that can + # check the arguments passed to the constructor, return a pre-existing SSLContext, and then + # unset the mock to prevent future issues. + def return_and_reset(*args, **kwargs): + ssl.SSLContext = original_ssl_ctx_cls + assert kwargs["protocol"] is ssl.PROTOCOL_TLS_CLIENT + return my_ssl_context + + mocker.patch.object(ssl, "SSLContext", side_effect=return_and_reset) + mocker.spy(my_ssl_context, "load_default_certs") + + ProvisioningSession( + provisioning_endpoint=FAKE_HOSTNAME, + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ) + + cfg = spy_mqtt_cls.call_args[0][0] + ctx = cfg.ssl_context + assert ctx is my_ssl_context + # NOTE: ctx protocol is checked in the `return_and_reset` side effect above + assert ctx.verify_mode == ssl.CERT_REQUIRED + assert ctx.check_hostname is True + assert ctx.load_default_certs.call_count == 1 + assert ctx.load_default_certs.call_args == mocker.call() + assert ctx.minimum_version == ssl.TLSVersion.TLSv1_2 + + @pytest.mark.it( + "Sets the stored SasTokenProvider (if any) on the ProvisioningClientConfig used to create the ProvisioningMQTTClient" + ) + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params) + async def 
test_sastoken_provider_cfg(self, mocker, shared_access_key, sastoken_fn, ssl_context): + spy_mqtt_cls = mocker.spy(mqtt, "ProvisioningMQTTClient") + + session = ProvisioningSession( + provisioning_endpoint=FAKE_HOSTNAME, + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + cfg = spy_mqtt_cls.call_args[0][0] + assert cfg.sastoken_provider is session._sastoken_provider + + @pytest.mark.it( + "Sets `auto_reconnect` to False on the ProvisioningClientConfig used to create the ProvisioningMQTTClient" + ) + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params) + async def test_auto_reconnect_cfg(self, mocker, shared_access_key, sastoken_fn, ssl_context): + spy_mqtt_cls = mocker.spy(mqtt, "ProvisioningMQTTClient") + + ProvisioningSession( + provisioning_endpoint=FAKE_HOSTNAME, + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + cfg = spy_mqtt_cls.call_args[0][0] + assert cfg.auto_reconnect is False + + @pytest.mark.it( + "Sets any provided optional keyword arguments on the ProvisioningClientConfig used to create the ProvisioningMQTTClient" + ) + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params) + @pytest.mark.parametrize("kwarg_name, kwarg_value", factory_kwargs) + async def test_kwargs( + self, + mocker, + shared_access_key, + sastoken_fn, + ssl_context, + kwarg_name, + kwarg_value, + ): + spy_mqtt_cls = mocker.spy(mqtt, "ProvisioningMQTTClient") + kwargs = {kwarg_name: kwarg_value} + + ProvisioningSession( + provisioning_endpoint=FAKE_HOSTNAME, + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + **kwargs + ) + + cfg = spy_mqtt_cls.call_args[0][0] + assert 
getattr(cfg, kwarg_name) == kwarg_value + + @pytest.mark.it("Sets the `wait_for_disconnect_task` attribute to None") + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params) + async def test_wait_for_disconnect_task(self, shared_access_key, sastoken_fn, ssl_context): + session = ProvisioningSession( + provisioning_endpoint=FAKE_HOSTNAME, + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + assert session._wait_for_disconnect_task is None + + @pytest.mark.it( + "Raises ValueError if neither `shared_access_key`, `sastoken_fn` nor `ssl_context` are provided as parameters" + ) + async def test_no_auth(self): + with pytest.raises(ValueError): + ProvisioningSession( + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + provisioning_endpoint=FAKE_HOSTNAME, + ) + + @pytest.mark.it( + "Raises ValueError if both `shared_access_key` and `sastoken_fn` are provided as parameters" + ) + async def test_conflicting_auth(self, optional_ssl_context): + with pytest.raises(ValueError): + ProvisioningSession( + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + provisioning_endpoint=FAKE_HOSTNAME, + shared_access_key=FAKE_SHARED_ACCESS_KEY, + sastoken_fn=sastoken_generator_fn, + ssl_context=optional_ssl_context, + ) + + @pytest.mark.it("Raises TypeError if an invalid keyword argument is provided") + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params) + async def test_bad_kwarg(self, shared_access_key, sastoken_fn, ssl_context): + with pytest.raises(TypeError): + ProvisioningSession( + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + provisioning_endpoint=FAKE_HOSTNAME, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + invalid_argument="some value", + ) + + @pytest.mark.it( + "Allows any exceptions raised when 
creating a SymmetricKeySigningMechanism to propagate" + ) + @pytest.mark.parametrize("exception", sk_sm_create_exceptions) + @pytest.mark.parametrize("shared_access_key, sastoken_fn, ssl_context", create_auth_params_sak) + async def test_sksm_raises( + self, mocker, shared_access_key, sastoken_fn, ssl_context, exception + ): + mocker.patch.object(sm, "SymmetricKeySigningMechanism", side_effect=exception) + assert sastoken_fn is None + + with pytest.raises(type(exception)) as e_info: + ProvisioningSession( + registration_id=FAKE_REGISTRATION_ID, + id_scope=FAKE_ID_SCOPE, + provisioning_endpoint=FAKE_HOSTNAME, + shared_access_key=shared_access_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + assert e_info.value is exception + + +@pytest.mark.describe("ProvisioningSession -- Context Manager Usage") +class TestProvisioningSessionContextManager: + @pytest.fixture + def session(self, disconnected_session): + return disconnected_session + + @pytest.mark.it( + "Starts the ProvisioningMQTTClient upon entry into the context manager, and stops it upon exit" + ) + async def test_mqtt_client_start_stop(self, session): + assert session._mqtt_client.start.await_count == 0 + assert session._mqtt_client.stop.await_count == 0 + + async with session as session: + assert session._mqtt_client.start.await_count == 1 + assert session._mqtt_client.stop.await_count == 0 + + assert session._mqtt_client.start.await_count == 1 + assert session._mqtt_client.stop.await_count == 1 + + @pytest.mark.it( + "Stops the ProvisioningMQTTClient upon exit, even if an error was raised within the block inside the context manager" + ) + async def test_mqtt_client_start_stop_with_failure(self, session, arbitrary_exception): + assert session._mqtt_client.start.await_count == 0 + assert session._mqtt_client.stop.await_count == 0 + + try: + async with session as session: + assert session._mqtt_client.start.await_count == 1 + assert session._mqtt_client.stop.await_count == 0 + raise 
arbitrary_exception + except type(arbitrary_exception): + pass + + assert session._mqtt_client.start.await_count == 1 + assert session._mqtt_client.stop.await_count == 1 + + @pytest.mark.it( + "Connect the ProvisioningMQTTClient upon entry into the context manager, and disconnect it upon exit" + ) + async def test_mqtt_client_connection(self, session): + assert session._mqtt_client.connect.await_count == 0 + assert session._mqtt_client.disconnect.await_count == 0 + + async with session as session: + assert session._mqtt_client.connect.await_count == 1 + assert session._mqtt_client.disconnect.await_count == 0 + + assert session._mqtt_client.connect.await_count == 1 + assert session._mqtt_client.disconnect.await_count == 1 + + @pytest.mark.it( + "Disconnect the ProvisioningMQTTClient upon exit, even if an error was raised within the block inside the context manager" + ) + async def test_mqtt_client_connection_with_failure(self, session, arbitrary_exception): + assert session._mqtt_client.connect.await_count == 0 + assert session._mqtt_client.disconnect.await_count == 0 + + try: + async with session as session: + assert session._mqtt_client.connect.await_count == 1 + assert session._mqtt_client.disconnect.await_count == 0 + raise arbitrary_exception + except type(arbitrary_exception): + pass + + assert session._mqtt_client.connect.await_count == 1 + assert session._mqtt_client.disconnect.await_count == 1 + + @pytest.mark.it( + "Starts the SasTokenProvider upon entry into the context manager, and stops it upon exit, if one exists" + ) + async def test_sastoken_provider_start_stop(self, session, mock_sastoken_provider): + session._sastoken_provider = mock_sastoken_provider + assert session._sastoken_provider.start.await_count == 0 + assert session._sastoken_provider.stop.await_count == 0 + + async with session as session: + assert session._sastoken_provider.start.await_count == 1 + assert session._sastoken_provider.stop.await_count == 0 + + assert 
session._sastoken_provider.start.await_count == 1 + assert session._sastoken_provider.stop.await_count == 1 + + @pytest.mark.it( + "Stops the SasTokenProvider upon exit, even if an error was raised within the block inside the context manager" + ) + async def test_sastoken_provider_start_stop_with_failure( + self, session, mock_sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = mock_sastoken_provider + assert session._sastoken_provider.start.await_count == 0 + assert session._sastoken_provider.stop.await_count == 0 + + try: + async with session as session: + assert session._sastoken_provider.start.await_count == 1 + assert session._sastoken_provider.stop.await_count == 0 + raise arbitrary_exception + except type(arbitrary_exception): + pass + + assert session._sastoken_provider.start.await_count == 1 + assert session._sastoken_provider.stop.await_count == 1 + + @pytest.mark.it("Can handle the case where SasTokenProvider does not exist") + async def test_no_sastoken_provider(self, session): + assert session._sastoken_provider is None + + async with session as session: + pass + + # If nothing raises here, the test passes + + @pytest.mark.it( + "Creates a Task from the MQTTClient's .wait_for_disconnect() coroutine method and stores it as the `wait_for_disconnect_task` attribute upon entry into the context manager, and cancels and clears the Task upon exit" + ) + async def test_wait_for_disconnect_task(self, mocker, session): + assert session._wait_for_disconnect_task is None + assert session._mqtt_client.wait_for_disconnect.call_count == 0 + + async with session as session: + # Task Created and Method called + assert isinstance(session._wait_for_disconnect_task, asyncio.Task) + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.call_args == mocker.call() + await asyncio.sleep(0.1) + assert 
session._mqtt_client.wait_for_disconnect.is_hanging() + # Returning method completes task (thus task corresponds to method) + session._mqtt_client.wait_for_disconnect.stop_hanging() + await asyncio.sleep(0.1) + assert session._wait_for_disconnect_task.done() + assert ( + session._wait_for_disconnect_task.result() + is session._mqtt_client.wait_for_disconnect.return_value + ) + # Replace the task with a mock so we can show it is cancelled/cleared on exit + mock_task = mocker.MagicMock() + session._wait_for_disconnect_task = mock_task + assert mock_task.cancel.call_count == 0 + + # Mocked task was cancelled and cleared + assert mock_task.cancel.call_count == 1 + assert session._wait_for_disconnect_task is None + + @pytest.mark.it( + "Cancels and clears the `wait_for_disconnect_task` Task, even if an error was raised within the block inside the context manager" + ) + async def test_wait_for_disconnect_task_with_failure(self, session, arbitrary_exception): + assert session._wait_for_disconnect_task is None + + try: + async with session as session: + task = session._wait_for_disconnect_task + assert task is not None + assert not task.done() + raise arbitrary_exception + except type(arbitrary_exception): + pass + + await asyncio.sleep(0.1) + assert session._wait_for_disconnect_task is None + assert task.done() + assert task.cancelled() + + @pytest.mark.it( + "Allows any errors raised within the block inside the context manager to propagate" + ) + @pytest.mark.parametrize( + "exception", + [ + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + # NOTE: it is important to test the CancelledError since it is a regular Exception in 3.7, + # but a BaseException from 3.8+ + pytest.param(asyncio.CancelledError(), id="CancelledError"), + ], + ) + async def test_error_propagation(self, session, exception): + with pytest.raises(type(exception)) as e_info: + async with session as session: + raise exception + assert e_info.value is exception + + # NOTE: 
This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Allows any errors raised while starting the SasTokenProvider during context manager entry to propagate" + ) + async def test_enter_sastoken_provider_start_raises( + self, session, mock_sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = mock_sastoken_provider + session._sastoken_provider.start.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)) as e_info: + async with session as session: + pass + assert e_info.value is arbitrary_exception + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Does not start or connect the ProvisioningMQTTClient, nor create the `wait_for_disconnect_task`, if an error is raised while starting the SasTokenProvider during context manager entry" + ) + async def test_enter_sastoken_provider_start_raises_cleanup( + self, mocker, session, mock_sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = mock_sastoken_provider + session._sastoken_provider.start.side_effect = arbitrary_exception + assert session._sastoken_provider.start.await_count == 0 + assert session._mqtt_client.start.await_count == 0 + assert session._mqtt_client.connect.await_count == 0 + assert session._wait_for_disconnect_task is None + spy_create_task = mocker.spy(asyncio, "create_task") + + with pytest.raises(type(arbitrary_exception)) as e_info: + async with session as session: + pass + assert e_info.value is arbitrary_exception + + assert session._sastoken_provider.start.await_count == 1 + assert session._mqtt_client.start.await_count == 0 + assert session._mqtt_client.connect.await_count == 0 + assert session._wait_for_disconnect_task is None + assert spy_create_task.call_count == 0 + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Allows any errors raised while starting the ProvisioningMQTTClient during context manager entry to propagate" + ) + async def 
test_enter_mqtt_client_start_raises(self, session, arbitrary_exception): + session._mqtt_client.start.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)) as e_info: + async with session as session: + pass + assert e_info.value is arbitrary_exception + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Stops the ProvisioningMQTTClient and SasTokenProvider (if present) that were previously started, and does not create the `wait_for_disconnect_task`, if an error is raised while starting the ProvisioningMQTTClient during context manager entry" + ) + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="SasTokenProvider present"), + pytest.param(None, id="SasTokenProvider not present"), + ], + ) + async def test_enter_mqtt_client_start_raises_cleanup( + self, mocker, session, sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = sastoken_provider + session._mqtt_client.start.side_effect = arbitrary_exception + assert session._mqtt_client.stop.await_count == 0 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 0 + assert session._wait_for_disconnect_task is None + spy_create_task = mocker.spy(asyncio, "create_task") + + with pytest.raises(type(arbitrary_exception)): + async with session as session: + pass + + assert session._mqtt_client.stop.await_count == 1 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 1 + assert session._wait_for_disconnect_task is None + assert spy_create_task.call_count == 0 + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Stops the SasTokenProvider (if present) even if an error was raised while stopping the ProvisioningMQTTClient in response to an error raised while starting the ProvisioningMQTTClient during context manager entry" + ) + @pytest.mark.parametrize( + "sastoken_provider", + [ + 
pytest.param(lazy_fixture("mock_sastoken_provider"), id="SasTokenProvider present"), + pytest.param(None, id="SasTokenProvider not present"), + ], + ) + async def test_enter_mqtt_client_start_raises_then_mqtt_client_stop_raises( + self, session, sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = sastoken_provider + session._mqtt_client.start.side_effect = arbitrary_exception + session._mqtt_client.stop.side_effect = arbitrary_exception + assert session._mqtt_client.stop.await_count == 0 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 0 + + with pytest.raises(type(arbitrary_exception)): + async with session as session: + pass + + assert session._mqtt_client.stop.await_count == 1 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 1 + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Stops the ProvisioningMQTTClient even if an error was raised while stopping the SasTokenProvider in response to an error raised while starting the ProvisioningMQTTClient during context manager entry" + ) + async def test_enter_mqtt_client_start_raises_then_sastoken_provider_stop_raises( + self, session, mock_sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = mock_sastoken_provider + session._mqtt_client.start.side_effect = arbitrary_exception + session._sastoken_provider.stop.side_effect = arbitrary_exception + assert session._mqtt_client.stop.await_count == 0 + assert session._sastoken_provider.stop.await_count == 0 + + with pytest.raises(type(arbitrary_exception)): + async with session as session: + pass + + assert session._mqtt_client.stop.await_count == 1 + assert session._sastoken_provider.stop.await_count == 1 + + @pytest.mark.it( + "Allows any errors raised while connecting with the ProvisioningMQTTClient during context manager entry to propagate" + ) + @pytest.mark.parametrize( + "exception", + [ + 
pytest.param(exc.MQTTConnectionFailedError(), id="MQTTConnectionFailedError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_enter_mqtt_client_connect_raises(self, session, exception): + session._mqtt_client.connect.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + async with session as session: + pass + assert e_info.value is exception + + @pytest.mark.it( + "Stops the ProvisioningMQTTClient and SasTokenProvider (if present) that were previously started, and does not create the `wait_for_disconnect_task`, if an error is raised while connecting during context manager entry" + ) + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="SasTokenProvider present"), + pytest.param(None, id="SasTokenProvider not present"), + ], + ) + @pytest.mark.parametrize( + "exception", + [ + pytest.param(exc.MQTTConnectionFailedError(), id="MQTTConnectionFailedError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_enter_mqtt_client_connect_raises_cleanup( + self, mocker, session, sastoken_provider, exception + ): + session._sastoken_provider = sastoken_provider + session._mqtt_client.connect.side_effect = exception + assert session._mqtt_client.stop.await_count == 0 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 0 + assert session._wait_for_disconnect_task is None + spy_create_task = mocker.spy(asyncio, "create_task") + + with pytest.raises(type(exception)): + async with session as session: + pass + + assert session._mqtt_client.stop.await_count == 1 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 1 + assert session._wait_for_disconnect_task is None + assert spy_create_task.call_count == 0 + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Stops the SasTokenProvider (if 
present) even if an error was raised while stopping the ProvisioningMQTTClient in response to an error raised while connecting the ProvisioningMQTTClient during context manager entry" + ) + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="SasTokenProvider present"), + pytest.param(None, id="SasTokenProvider not present"), + ], + ) + @pytest.mark.parametrize( + "exception", + [ + pytest.param(exc.MQTTConnectionFailedError(), id="MQTTConnectionFailedError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_enter_mqtt_client_connect_raises_then_mqtt_client_stop_raises( + self, session, sastoken_provider, exception, arbitrary_exception + ): + session._sastoken_provider = sastoken_provider + session._mqtt_client.connect.side_effect = exception # Realistic failure + session._mqtt_client.stop.side_effect = arbitrary_exception # Shouldn't happen + assert session._mqtt_client.stop.await_count == 0 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 0 + + # NOTE: arbitrary_exception is raised here instead of exception - this is because it + # happened second, during resolution of exception, thus taking precedence + with pytest.raises(type(arbitrary_exception)): + async with session as session: + pass + + assert session._mqtt_client.stop.await_count == 1 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 1 + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Stops the ProvisioningMQTTClient even if an error was raised while stopping the SasTokenProvider in response to an error raised while connecting the ProvisioningMQTTClient during context manager entry" + ) + @pytest.mark.parametrize( + "exception", + [ + pytest.param(exc.MQTTConnectionFailedError(), id="MQTTConnectionFailedError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected 
Exception"), + ], + ) + async def test_enter_mqtt_client_connect_raises_then_sastoken_provider_stop_raises( + self, session, mock_sastoken_provider, exception, arbitrary_exception + ): + session._sastoken_provider = mock_sastoken_provider + session._mqtt_client.connect.side_effect = exception # Realistic failure + session._sastoken_provider.stop.side_effect = arbitrary_exception # Shouldn't happen + assert session._mqtt_client.stop.await_count == 0 + assert session._sastoken_provider.stop.await_count == 0 + + # NOTE: arbitrary_exception is raised here instead of exception - this is because it + # happened second, during resolution of exception, thus taking precedence + with pytest.raises(type(arbitrary_exception)): + async with session as session: + pass + + assert session._mqtt_client.stop.await_count == 1 + assert session._sastoken_provider.stop.await_count == 1 + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Allows any errors raised while disconnecting the ProvisioningMQTTClient during context manager exit to propagate" + ) + async def test_exit_disconnect_raises(self, session, arbitrary_exception): + session._mqtt_client.disconnect.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)) as e_info: + async with session as session: + pass + assert e_info.value is arbitrary_exception + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Stops the ProvisioningMQTTClient and SasTokenProvider (if present), and cancels and clears the `wait_for_disconnect_task`, even if an error was raised while disconnecting the ProvisioningMQTTClient during context manager exit" + ) + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="SasTokenProvider present"), + pytest.param(None, id="SasTokenProvider not present"), + ], + ) + async def test_exit_disconnect_raises_cleanup( + self, session, sastoken_provider, arbitrary_exception + ): + 
session._sastoken_provider = sastoken_provider + session._mqtt_client.disconnect.side_effect = arbitrary_exception + assert session._mqtt_client.stop.await_count == 0 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 0 + assert session._wait_for_disconnect_task is None + + with pytest.raises(type(arbitrary_exception)): + async with session as session: + conn_drop_task = session._wait_for_disconnect_task + assert not conn_drop_task.done() + + assert session._mqtt_client.stop.await_count == 1 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 1 + await asyncio.sleep(0.1) + assert session._wait_for_disconnect_task is None + assert conn_drop_task.cancelled() + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Allows any errors raised while stopping the ProvisioningMQTTClient during context manager exit to propagate" + ) + async def test_exit_mqtt_client_stop_raises(self, session, arbitrary_exception): + session._mqtt_client.stop.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)) as e_info: + async with session as session: + pass + assert e_info.value is arbitrary_exception + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Disconnects the ProvisioningMQTTClient and stops the SasTokenProvider (if present), and cancels and clears the `wait_for_disconnect_task`, even if an error was raised while stopping the ProvisioningMQTTClient during context manager exit" + ) + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="SasTokenProvider present"), + pytest.param(None, id="SasTokenProvider not present"), + ], + ) + async def test_exit_mqtt_client_stop_raises_cleanup( + self, session, sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = sastoken_provider + session._mqtt_client.stop.side_effect = arbitrary_exception + assert 
session._mqtt_client.disconnect.await_count == 0 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 0 + assert session._wait_for_disconnect_task is None + + with pytest.raises(type(arbitrary_exception)): + async with session as session: + conn_drop_task = session._wait_for_disconnect_task + assert not conn_drop_task.done() + + assert session._mqtt_client.disconnect.await_count == 1 + if session._sastoken_provider: + assert session._sastoken_provider.stop.await_count == 1 + await asyncio.sleep(0.1) + assert session._wait_for_disconnect_task is None + assert conn_drop_task.cancelled() + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Allows any errors raised while stopping the SasTokenProvider during context manager exit to propagate" + ) + async def test_exit_sastoken_provider_stop_raises( + self, session, mock_sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = mock_sastoken_provider + session._sastoken_provider.stop.side_effect = arbitrary_exception + + with pytest.raises(type(arbitrary_exception)): + async with session as session: + pass + + # NOTE: This shouldn't happen, but we test it anyway + @pytest.mark.it( + "Disconnects and stops the ProvisioningMQTTClient, and cancels and clears the `wait_for_disconnect_task`, even if an error was raised while stopping the SasTokenProvider during context manager exit" + ) + async def test_exit_sastoken_provider_stop_raises_cleanup( + self, session, mock_sastoken_provider, arbitrary_exception + ): + session._sastoken_provider = mock_sastoken_provider + session._sastoken_provider.stop.side_effect = arbitrary_exception + assert session._mqtt_client.disconnect.await_count == 0 + assert session._mqtt_client.stop.await_count == 0 + assert session._wait_for_disconnect_task is None + + with pytest.raises(type(arbitrary_exception)): + async with session as session: + conn_drop_task = session._wait_for_disconnect_task + assert not 
conn_drop_task.done() + + assert session._mqtt_client.disconnect.await_count == 1 + assert session._mqtt_client.stop.await_count == 1 + await asyncio.sleep(0.1) + assert session._wait_for_disconnect_task is None + assert conn_drop_task.cancelled() + + # TODO: consider adding detailed cancellation tests + # Not sure how cancellation would work in a context manager situation, needs more investigation + + +@pytest.mark.describe("ProvisioningSession - .register()") +class TestProvisioningSessionRegister: + @pytest.mark.it( + "Invokes .send_register() on the ProvisioningMQTTClient, passing the provided `payload`, if provided" + ) + @pytest.mark.parametrize("payload", json_serializable_payload_params) + async def test_invoke_with_payload(self, mocker, session, payload): + assert session._mqtt_client.send_register.await_count == 0 + + await session.register(payload) + + assert session._mqtt_client.send_register.await_count == 1 + assert session._mqtt_client.send_register.await_args == mocker.call(payload) + + @pytest.mark.it( + "Invokes .send_register() on the ProvisioningMQTTClient, passing None, if no `payload` is provided" + ) + async def test_invoke_no_payload(self, mocker, session): + assert session._mqtt_client.send_register.await_count == 0 + + await session.register() + + assert session._mqtt_client.send_register.await_count == 1 + assert session._mqtt_client.send_register.await_args == mocker.call(None) + + @pytest.mark.it("Allows any exceptions raised by the ProvisioningMQTTClient to propagate") + @pytest.mark.parametrize( + "exception", + [ + pytest.param(exc.ProvisioningServiceError(), id="ProvisioningServiceError"), + pytest.param(exc.MQTTError(5), id="MQTTError"), + pytest.param(asyncio.CancelledError(), id="CancelledError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Error"), + ], + ) + @pytest.mark.parametrize("payload", json_serializable_payload_params) + async def test_mqtt_client_raises(self, session, exception, payload): + 
session._mqtt_client.send_register.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await session.register(payload) + # CancelledError doesn't propagate in some versions of Python + # TODO: determine which versions exactly + if not isinstance(exception, asyncio.CancelledError): + assert e_info.value is exception + + @pytest.mark.it( + "Raises SessionError without invoking .register() on the ProvisioningMQTTClient if it is not connected" + ) + @pytest.mark.parametrize("payload", json_serializable_payload_params) + async def test_not_connected(self, mocker, session, payload): + conn_property_mock = mocker.PropertyMock(return_value=False) + type(session._mqtt_client).connected = conn_property_mock + + with pytest.raises(exc.SessionError): + await session.register(payload) + assert session._mqtt_client.send_register.call_count == 0 + + @pytest.mark.it( + "Raises CancelledError if an expected disconnect occurs in the ProvisioningMQTTClient while waiting for the operation to complete" + ) + @pytest.mark.parametrize("payload", json_serializable_payload_params) + async def test_expected_disconnect_during_send(self, session, payload): + session._mqtt_client.send_register = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(session.register(payload)) + + # Hanging, waiting for send to finish + await session._mqtt_client.send_register.wait_for_hang() + assert not t.done() + + # No disconnect yet + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.is_hanging() + + # Simulate expected disconnect + session._mqtt_client.wait_for_disconnect.return_value = None + session._mqtt_client.wait_for_disconnect.stop_hanging() + + with pytest.raises(asyncio.CancelledError): + await t + + @pytest.mark.it( + "Raises the MQTTConnectionDroppedError that caused the unexpected disconnect, if an unexpected disconnect occurs in the " + 
"ProvisioningMQTTClient while waiting for the operation to complete" + ) + @pytest.mark.parametrize("payload", json_serializable_payload_params) + async def test_unexpected_disconnect_during_send(self, session, payload): + session._mqtt_client.send_register = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(session.register(payload)) + + # Hanging, waiting for send to finish + await session._mqtt_client.send_register.wait_for_hang() + assert not t.done() + + # No unexpected disconnect yet + assert not session._wait_for_disconnect_task.done() + assert session._mqtt_client.wait_for_disconnect.call_count == 1 + assert session._mqtt_client.wait_for_disconnect.is_hanging() + + # Simulate unexpected disconnect + cause = exc.MQTTConnectionDroppedError(rc=7) + session._mqtt_client.wait_for_disconnect.return_value = cause + session._mqtt_client.wait_for_disconnect.stop_hanging() + + with pytest.raises(exc.MQTTConnectionDroppedError) as e_info: + await t + assert e_info.value is cause + + @pytest.mark.it( + "Can be cancelled while waiting for the ProvisioningMQTTClient operation to complete" + ) + @pytest.mark.parametrize("payload", json_serializable_payload_params) + async def test_cancel_during_send(self, session, payload): + session._mqtt_client.send_register = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(session.register(payload)) + + # Hanging, waiting for send to finish + await session._mqtt_client.send_register.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t diff --git a/tests/unit/test_request_response.py b/tests/unit/test_request_response.py new file mode 100644 index 000000000..7593ace02 --- /dev/null +++ b/tests/unit/test_request_response.py @@ -0,0 +1,258 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import asyncio +import pytest +import uuid +from azure.iot.device.request_response import Request, Response, RequestLedger + +fake_request_id = str(uuid.uuid4()) +fake_status = 200 +fake_body = "{'data' : 'value'}" +fake_props = { + "fake_key_1": "fake_val_1", + "fake_key_2": "fake_val_2", +} + + +@pytest.mark.describe("Response") +class TestResponse: + @pytest.mark.it( + "Instantiates with the provided request_id, status, body and properties stored as attributes" + ) + def test_attributes(self): + r = Response( + request_id=fake_request_id, status=fake_status, body=fake_body, properties=fake_props + ) + assert r.request_id == fake_request_id + assert r.status == fake_status + assert r.body == fake_body + assert r.properties == fake_props + + +@pytest.mark.describe("Request") +class TestRequest: + @pytest.mark.it( + "Instantiates with a generated UUID (in string format) as the `request_id` attribute, if no `request_id` is provided" + ) + async def test_request_id_attr_generated(self, mocker): + uuid4_mock = mocker.patch.object(uuid, "uuid4") + r = Request() + assert uuid4_mock.call_count == 1 + assert r.request_id == str(uuid4_mock.return_value) + + @pytest.mark.it( + "Instantiates with the `request_id` attribute set to the provided `request_id`, if it is provided" + ) + async def test_request_id_provided(self, mocker): + my_request_id = "3226c2f7-3d30-425c-b83b-0c34335f8220" + r = Request(request_id=my_request_id) + assert r.request_id == my_request_id + + @pytest.mark.it( + "Instantiates with an incomplete async Future as the `response_future` attribute" + ) + async def test_response_future_attr(self): + r = Request() + assert isinstance(r.response_future, asyncio.Future) + assert not r.response_future.done() + + @pytest.mark.it( + "Awaits and returns the result of the `response_future` when `.get_response()` is invoked" + 
) + async def test_get_response(self): + req = Request() + assert not req.response_future.done() + + # .get_response() doesn't return yet + task = asyncio.create_task(req.get_response()) + await asyncio.sleep(0.1) + assert not task.done() + + # Add a result to the future so the task will complete + resp = Response(request_id=req.request_id, status=fake_status, body=fake_body) + req.response_future.set_result(resp) + result = await task + assert result is resp + + +@pytest.mark.describe("RequestLedger") +class TestRequestLedger: + @pytest.fixture + async def ledger(self): + return RequestLedger() + + @pytest.mark.it("Instantiates with an empty dictionary of pending requests") + async def test_initial_ledger(self): + ledger = RequestLedger() + assert isinstance(ledger.pending, dict) + assert len(ledger.pending) == 0 + + @pytest.mark.it( + "Creates and returns a new Request, tracking it in the pending requests dictionary, for each invocation of .create_request()" + ) + async def test_create_request(self, ledger): + assert len(ledger.pending) == 0 + req1 = await ledger.create_request() + assert len(ledger.pending) == 1 + req2 = await ledger.create_request() + assert len(ledger.pending) == 2 + req3 = await ledger.create_request() + assert len(ledger.pending) == 3 + + assert ledger.pending[req1.request_id] == req1.response_future + assert ledger.pending[req2.request_id] == req2.response_future + assert ledger.pending[req3.request_id] == req3.response_future + + @pytest.mark.it( + "New Requests can have their request_id provided when invoking .create_request()" + ) + async def test_create_request_custom_id(self, ledger): + assert len(ledger.pending) == 0 + request_id_1 = "3226c2f7-3d30-425c-b83b-0c34335f8220" + req1 = await ledger.create_request(request_id=request_id_1) + assert len(ledger.pending) == 1 + request_id_2 = "d9d7ce4d-3be9-498b-abde-913b81b880e5" + req2 = await ledger.create_request(request_id=request_id_2) + assert len(ledger.pending) == 2 + + assert 
ledger.pending[req1.request_id] == req1.response_future + assert ledger.pending[req2.request_id] == req2.response_future + assert req1.request_id == request_id_1 + assert req2.request_id == request_id_2 + + @pytest.mark.it( + "Raises ValueError if the request id provided via an invocation of .create_request() is already being tracked" + ) + async def test_create_duplicate_id(self, ledger): + req_id = "3226c2f7-3d30-425c-b83b-0c34335f8220" + req = await ledger.create_request(request_id=req_id) + assert req.request_id in ledger.pending + assert req.request_id == req_id + + with pytest.raises(ValueError): + await ledger.create_request(request_id=req_id) + + @pytest.mark.it( + "Removes a tracked Request from the ledger that matches the request id provided via an invocation of .delete_request()" + ) + async def test_delete_request(self, ledger): + req1 = await ledger.create_request() + req2 = await ledger.create_request() + assert req1.request_id in ledger.pending + assert req2.request_id in ledger.pending + + await ledger.delete_request(req2.request_id) + assert req2.request_id not in ledger.pending + assert req1.request_id in ledger.pending + await ledger.delete_request(req1.request_id) + assert req1.request_id not in ledger.pending + + @pytest.mark.it( + "Raises a KeyError if the request id provided to an invocation of .delete_request() does not match one in the ledger" + ) + async def test_delete_request_bad_id(self, ledger): + req1 = await ledger.create_request() + assert req1.request_id in ledger.pending + await ledger.delete_request(req1.request_id) + assert req1.request_id not in ledger.pending + + with pytest.raises(KeyError): + await ledger.delete_request(req1.request_id) + + @pytest.mark.it( + "Completes a tracked Request and removes it from the Ledger when a Response that matches its request id is provided via an invocation of .match_response()" + ) + async def test_match_response(self, ledger): + assert len(ledger.pending) == 0 + req1 = await 
ledger.create_request() + assert len(ledger.pending) == 1 + req2 = await ledger.create_request() + assert len(ledger.pending) == 2 + req3 = await ledger.create_request() + assert len(ledger.pending) == 3 + + gr_task1 = asyncio.create_task(req1.get_response()) + gr_task2 = asyncio.create_task(req2.get_response()) + gr_task3 = asyncio.create_task(req3.get_response()) + await asyncio.sleep(0.1) + assert not gr_task1.done() + assert not gr_task2.done() + assert not gr_task3.done() + + resp1 = Response(request_id=req1.request_id, status=fake_status, body=fake_body) + resp2 = Response(request_id=req2.request_id, status=fake_status, body=fake_body) + resp3 = Response(request_id=req3.request_id, status=fake_status, body=fake_body) + + await ledger.match_response(resp2) + assert len(ledger.pending) == 2 + assert req2.request_id not in ledger.pending + assert await gr_task2 is resp2 + + await ledger.match_response(resp3) + assert len(ledger.pending) == 1 + assert req3.request_id not in ledger.pending + assert await gr_task3 is resp3 + + await ledger.match_response(resp1) + assert len(ledger.pending) == 0 + assert req1.request_id not in ledger.pending + assert await gr_task1 is resp1 + + @pytest.mark.it( + "Raises a KeyError if the Response provided to an invocation of .match_response() does not have a request id that matches any tracked Request" + ) + async def test_match_response_bad_id(self, ledger): + req1 = await ledger.create_request() + assert req1.request_id in ledger.pending + await ledger.delete_request(req1.request_id) + assert req1.request_id not in ledger.pending + + resp1 = Response(request_id=req1.request_id, status=fake_status, body=fake_body) + + with pytest.raises(KeyError): + await ledger.match_response(resp1) + + @pytest.mark.it( + "Implements support for len() by returning the number of pending items in the ledger" + ) + async def test_len(self, ledger): + assert len(ledger.pending) == 0 + assert len(ledger) == 0 + + req1 = await ledger.create_request() + 
assert len(ledger.pending) == len(ledger) == 1 + req2 = await ledger.create_request() + assert len(ledger.pending) == len(ledger) == 2 + + resp1 = Response(request_id=req1.request_id, status=fake_status, body=fake_body) + await ledger.match_response(resp1) + assert len(ledger.pending) == len(ledger) == 1 + resp2 = Response(request_id=req2.request_id, status=fake_status, body=fake_body) + await ledger.match_response(resp2) + assert len(ledger.pending) == len(ledger) == 0 + + @pytest.mark.it("Implements support for identifying if a request_id is currently pending") + async def test_contains(self, ledger): + assert len(ledger) == 0 + + req1 = await ledger.create_request() + assert len(ledger) == 1 + assert req1.request_id in ledger + req2 = await ledger.create_request() + assert len(ledger) == 2 + assert req2.request_id in ledger + + # Remove req1 from ledger by matching + resp = Response(request_id=req1.request_id, status=fake_status, body=fake_body) + await ledger.match_response(resp) + assert len(ledger) == 1 + assert req1.request_id not in ledger + assert req2.request_id in ledger + + # Remove req2 from ledger by deletion + await ledger.delete_request(req2.request_id) + assert len(ledger) == 0 + assert req2.request_id not in ledger diff --git a/tests/unit/test_sastoken.py b/tests/unit/test_sastoken.py new file mode 100644 index 000000000..8b75c144c --- /dev/null +++ b/tests/unit/test_sastoken.py @@ -0,0 +1,785 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import asyncio +import logging +import pytest +import sys +import time +import urllib.parse +from pytest_lazyfixture import lazy_fixture +from azure.iot.device.sastoken import ( + SasToken, + InternalSasTokenGenerator, + ExternalSasTokenGenerator, + SasTokenProvider, + SasTokenError, + TOKEN_FORMAT, + DEFAULT_TOKEN_UPDATE_MARGIN, +) +from azure.iot.device import sastoken as st + +logging.basicConfig(level=logging.DEBUG) + +FAKE_URI = "some/resource/location" +FAKE_SIGNED_DATA = "8NJRMT83CcplGrAGaUVIUM/md5914KpWVNngSVoF9/M=" +FAKE_SIGNED_DATA2 = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=" +FAKE_CURRENT_TIME = 10000000000.0 # We living in 2286! + + +def token_parser(token_str): + """helper function that parses a token string for individual values""" + token_map = {} + kv_string = token_str.split(" ")[1] + kv_pairs = kv_string.split("&") + for kv in kv_pairs: + t = kv.split("=") + token_map[t[0]] = t[1] + return token_map + + +def get_expiry_time(): + return int(time.time()) + 3600 # One hour from right now, + + +@pytest.fixture +def sastoken_str(): + return TOKEN_FORMAT.format( + resource=urllib.parse.quote(FAKE_URI, safe=""), + signature=urllib.parse.quote(FAKE_SIGNED_DATA, safe=""), + expiry=get_expiry_time(), + ) + + +@pytest.fixture +def sastoken(sastoken_str): + return SasToken(sastoken_str) + + +@pytest.fixture +async def mock_signing_mechanism(mocker): + mock_sm = mocker.AsyncMock() + mock_sm.sign.return_value = FAKE_SIGNED_DATA + return mock_sm + + +@pytest.fixture(params=["Generator Function", "Generator Coroutine Function"]) +def mock_token_generator_fn(mocker, request, sastoken_str): + if request.param == "Function": + return mocker.MagicMock(return_value=sastoken_str) + else: + return mocker.AsyncMock(return_value=sastoken_str) + + +@pytest.fixture(params=["InternalSasTokenGenerator", "ExternalSasTokenGenerator"]) +def sastoken_generator(request, mocker, 
mock_signing_mechanism, sastoken_str): + if request.param == "ExternalSasTokenGenerator": + # We don't care about the difference between sync/async generator_fns when testing + # at this level of abstraction, so just pick one + generator = ExternalSasTokenGenerator(mocker.MagicMock(return_value=sastoken_str)) + else: + generator = InternalSasTokenGenerator(mock_signing_mechanism, FAKE_URI) + mocker.spy(generator, "generate_sastoken") + return generator + + +@pytest.fixture +async def sastoken_provider(sastoken_generator): + provider = SasTokenProvider(sastoken_generator) + await provider.start() + # Creating from the generator invokes a call on the generator, so reset the spy mock + # so it doesn't throw off any testing logic + provider._generator.generate_sastoken.reset_mock() + yield provider + await provider.stop() + + +@pytest.mark.describe("SasToken") +class TestSasToken: + @pytest.mark.it("Instantiates from a valid SAS Token string") + def test_instantiates_from_token_string(self, sastoken_str): + s = SasToken(sastoken_str) + assert s._token_str == sastoken_str + + @pytest.mark.it("Raises a ValueError error if instantiating from an invalid SAS Token string") + @pytest.mark.parametrize( + "invalid_token_str", + [ + pytest.param( + "sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se=12321312", + id="Incomplete token format", + ), + pytest.param( + "SharedERRORSignature sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se=12321312", + id="Invalid token format", + ), + pytest.param( + "SharedAccessignature sr=some%2Fresource%2Flocationsig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se12321312", + id="Token values incorectly formatted", + ), + pytest.param( + "SharedAccessSignature sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se=12321312", + id="Missing resource value", + ), + pytest.param( + "SharedAccessSignature sr=some%2Fresource%2Flocation&se=12321312", + id="Missing signature value", + ), + 
pytest.param( + "SharedAccessSignature sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=", + id="Missing expiry value", + ), + ], + ) + def test_raises_error_invalid_token_string(self, invalid_token_str): + with pytest.raises(ValueError): + SasToken(invalid_token_str) + + @pytest.mark.it("Returns the SAS token string as the string representation of the object") + def test_str_rep(self, sastoken_str): + sastoken = SasToken(sastoken_str) + assert str(sastoken) == sastoken_str + + @pytest.mark.it( + "Instantiates with the .expiry_time property corresponding to the expiry time of the given SAS Token string (as a float)" + ) + def test_instantiates_expiry_time(self, sastoken_str): + sastoken = SasToken(sastoken_str) + expected_expiry_time = token_parser(sastoken_str)["se"] + assert sastoken.expiry_time == float(expected_expiry_time) + + @pytest.mark.it("Maintains .expiry_time as a read-only property") + def test_expiry_time_read_only(self, sastoken): + with pytest.raises(AttributeError): + sastoken.expiry_time = 12312312312123 + + @pytest.mark.it( + "Instantiates with the .resource_uri property corresponding to the URL decoded URI of the given SAS Token string" + ) + def test_instantiates_resource_uri(self, sastoken_str): + sastoken = SasToken(sastoken_str) + resource_uri = token_parser(sastoken_str)["sr"] + assert resource_uri != sastoken.resource_uri + assert resource_uri == urllib.parse.quote(sastoken.resource_uri, safe="") + assert urllib.parse.unquote(resource_uri) == sastoken.resource_uri + + @pytest.mark.it("Maintains .resource_uri as a read-only property") + def test_resource_uri_read_only(self, sastoken): + with pytest.raises(AttributeError): + sastoken.resource_uri = "new%2Ffake%2Furi" + + @pytest.mark.it( + "Instantiates with the .signature property corresponding to the URL decoded signature of the given SAS Token string" + ) + def test_instantiates_signature(self, sastoken_str): + sastoken = SasToken(sastoken_str) + signature = 
token_parser(sastoken_str)["sig"] + assert signature != sastoken.signature + assert signature == urllib.parse.quote(sastoken.signature, safe="") + assert urllib.parse.unquote(signature) == sastoken.signature + + @pytest.mark.it("Maintains .signature as a read-only property") + def test_signature_read_only(self, sastoken): + with pytest.raises(AttributeError): + sastoken.signature = "asdfas" + + +@pytest.mark.describe("InternalSasTokenGenerator -- Instantiation") +class TestSasTokenGeneratorInstantiation: + @pytest.mark.it("Stores the provided signing mechanism as an attribute") + def test_signing_mechanism(self, mock_signing_mechanism): + generator = InternalSasTokenGenerator( + signing_mechanism=mock_signing_mechanism, uri=FAKE_URI, ttl=4700 + ) + assert generator.signing_mechanism is mock_signing_mechanism + + @pytest.mark.it("Stores the provided URI as an attribute") + def test_uri(self, mock_signing_mechanism): + generator = InternalSasTokenGenerator( + signing_mechanism=mock_signing_mechanism, uri=FAKE_URI, ttl=4700 + ) + assert generator.uri == FAKE_URI + + @pytest.mark.it("Stores the provided TTL as an attribute") + def test_ttl(self, mock_signing_mechanism): + generator = InternalSasTokenGenerator( + signing_mechanism=mock_signing_mechanism, uri=FAKE_URI, ttl=4700 + ) + assert generator.ttl == 4700 + + @pytest.mark.it("Defaults to using 3600 as the TTL if not provided") + def test_ttl_default(self, mock_signing_mechanism): + generator = InternalSasTokenGenerator( + signing_mechanism=mock_signing_mechanism, uri=FAKE_URI + ) + assert generator.ttl == 3600 + + +@pytest.mark.describe("InternalSasTokenGenerator - .generate_sastoken()") +class TestSasTokenGeneratorGenerateSastoken: + @pytest.fixture + def sastoken_generator(self, mock_signing_mechanism): + return InternalSasTokenGenerator( + signing_mechanism=mock_signing_mechanism, uri=FAKE_URI, ttl=4700 + ) + + @pytest.mark.it( + "Returns a newly generated SasToken for the configured URI that is valid for TTL 
seconds" + ) + async def test_token_expiry(self, mocker, sastoken_generator): + # Patch time.time() to return a fake time so that it's easy to check the delta with expiry + mocker.patch.object(time, "time", return_value=FAKE_CURRENT_TIME) + expected_expiry = FAKE_CURRENT_TIME + sastoken_generator.ttl + token = await sastoken_generator.generate_sastoken() + assert isinstance(token, SasToken) + assert token.expiry_time == expected_expiry + assert token.resource_uri == sastoken_generator.uri + assert token._token_info["sr"] == urllib.parse.quote(sastoken_generator.uri, safe="") + assert token.resource_uri != token._token_info["sr"] + + @pytest.mark.it( + "Creates the resulting SasToken's signature by using the InternalSasTokenGenerator's signing mechanism to sign a concatenation of the (URL encoded) URI and (URL encoded, int converted) desired expiry time" + ) + async def test_token_signature(self, mocker, sastoken_generator): + assert sastoken_generator.signing_mechanism.await_count == 0 + mocker.patch.object(time, "time", return_value=FAKE_CURRENT_TIME) + expected_expiry = int(FAKE_CURRENT_TIME + sastoken_generator.ttl) + expected_data_to_sign = ( + urllib.parse.quote(sastoken_generator.uri, safe="") + "\n" + str(expected_expiry) + ) + + token = await sastoken_generator.generate_sastoken() + + assert sastoken_generator.signing_mechanism.sign.await_count == 1 + assert sastoken_generator.signing_mechanism.sign.await_args == mocker.call( + expected_data_to_sign + ) + assert token._token_info["sig"] == urllib.parse.quote( + sastoken_generator.signing_mechanism.sign.return_value, safe="" + ) + assert token.signature == sastoken_generator.signing_mechanism.sign.return_value + assert token.signature != token._token_info["sig"] + + @pytest.mark.it("Raises a SasTokenError if an exception is raised by the signing mechanism") + async def test_signing_mechanism_raises(self, sastoken_generator, arbitrary_exception): + sastoken_generator.signing_mechanism.sign.side_effect = 
arbitrary_exception + + with pytest.raises(SasTokenError) as e_info: + await sastoken_generator.generate_sastoken() + assert e_info.value.__cause__ is arbitrary_exception + + +@pytest.mark.describe("ExternalSasTokenGenerator -- Instantiation") +class TestExternalSasTokenGeneratorInstantiation: + @pytest.mark.it("Stores the provided generator_fn callable as an attribute") + def test_generator_fn_attribute(self, mock_token_generator_fn): + sastoken_generator = ExternalSasTokenGenerator(mock_token_generator_fn) + assert sastoken_generator.generator_fn is mock_token_generator_fn + + +@pytest.mark.describe("ExternalSasTokenGenerator -- .generate_sastoken()") +class TestExternalSasTokenGeneratorGenerateSasToken: + @pytest.fixture + def sastoken_generator(self, mock_token_generator_fn): + return ExternalSasTokenGenerator(mock_token_generator_fn) + + @pytest.mark.it( + "Generates a new SasToken from the SAS Token string returned by the configured generator_fn callable" + ) + async def test_returns_token(self, mocker, sastoken_generator): + if isinstance(sastoken_generator.generator_fn, mocker.AsyncMock): + assert sastoken_generator.generator_fn.await_count == 0 + else: + assert sastoken_generator.generator_fn.call_count == 0 + + token = await sastoken_generator.generate_sastoken() + assert isinstance(token, SasToken) + + if isinstance(sastoken_generator.generator_fn, mocker.AsyncMock): + assert sastoken_generator.generator_fn.await_count == 1 + assert sastoken_generator.generator_fn.await_args == mocker.call() + else: + assert sastoken_generator.generator_fn.call_count == 1 + assert sastoken_generator.generator_fn.call_args == mocker.call() + + assert str(token) == sastoken_generator.generator_fn.return_value + + @pytest.mark.it( + "Raises SasTokenError if an exception is raised while trying to generate a SAS Token string with the generator_fn" + ) + async def test_generator_fn_raises(self, sastoken_generator, arbitrary_exception): + 
sastoken_generator.generator_fn.side_effect = arbitrary_exception + + with pytest.raises(SasTokenError) as e_info: + await sastoken_generator.generate_sastoken() + assert e_info.value.__cause__ is arbitrary_exception + + @pytest.mark.it("Raises SasTokenError if the generated SAS Token string is invalid") + async def test_invalid_token(self, sastoken_generator): + sastoken_generator.generator_fn.return_value = "not a sastoken" + + with pytest.raises(SasTokenError) as e_info: + await sastoken_generator.generate_sastoken() + assert isinstance(e_info.value.__cause__, ValueError) + + +@pytest.mark.describe("SasTokenProvider -- Instantiation") +class TestSasTokenProviderInstantiation: + @pytest.mark.it("Stores the provided SasTokenGenerator") + async def test_generator_fn(self, sastoken_generator): + provider = SasTokenProvider(sastoken_generator) + assert provider._generator is sastoken_generator + + @pytest.mark.it("Sets the token update margin to the DEFAULT_TOKEN_UPDATE_MARGIN") + async def test_token_update_margin(self, sastoken, sastoken_generator): + provider = SasTokenProvider(sastoken_generator) + assert provider._token_update_margin == DEFAULT_TOKEN_UPDATE_MARGIN + + @pytest.mark.it("Sets the current token to None") + async def test_current_token(self, sastoken_generator): + provider = SasTokenProvider(sastoken_generator) + assert provider._current_token is None + + @pytest.mark.it("Sets the 'keep token fresh' background task attribute to None") + async def test_background_task(self, sastoken_generator): + provider = SasTokenProvider(sastoken_generator) + assert provider._keep_token_fresh_bg_task is None + + +@pytest.mark.describe("SasTokenProvider - .start()") +class TestSasTokenProviderStart: + @pytest.mark.it( + "Generates a new SasToken using the stored SasTokenGenerator and sets it as the current token" + ) + async def test_generates_current_token(self, mocker, sastoken_generator): + provider = SasTokenProvider(sastoken_generator) + assert 
sastoken_generator.generate_sastoken.await_count == 0 + + await provider.start() + + assert sastoken_generator.generate_sastoken.await_count == 1 + assert sastoken_generator.generate_sastoken.await_args == mocker.call() + assert isinstance(provider._current_token, SasToken) + assert provider._current_token == sastoken_generator.generate_sastoken.spy_return + + # Cleanup + await provider.stop() + + @pytest.mark.it("Sends notification of new token availability") + async def test_notify(self, mocker, sastoken_generator): + provider = SasTokenProvider(sastoken_generator) + notification_spy = mocker.spy(provider._new_sastoken_available, "notify_all") + assert notification_spy.call_count == 0 + + await provider.start() + + assert notification_spy.call_count == 1 + + @pytest.mark.it("Allows any exception raised while trying to generate a SasToken to propagate") + @pytest.mark.parametrize( + "exception", + [ + pytest.param(SasTokenError("token error"), id="SasTokenError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_generation_raises(self, sastoken_generator, exception): + sastoken_generator.generate_sastoken.side_effect = exception + provider = SasTokenProvider(sastoken_generator) + + with pytest.raises(type(exception)) as e_info: + await provider.start() + assert e_info.value is exception + + @pytest.mark.it("Raises a SasTokenError if the generated SAS Token string has already expired") + async def test_expired_token(self, mocker, sastoken_generator): + expired_token_str = TOKEN_FORMAT.format( + resource=urllib.parse.quote(FAKE_URI, safe=""), + signature=urllib.parse.quote(FAKE_SIGNED_DATA, safe=""), + expiry=int(time.time()) - 3600, # 1 hour ago + ) + sastoken_generator.generate_sastoken = mocker.AsyncMock() + sastoken_generator.generate_sastoken.return_value = SasToken(expired_token_str) + provider = SasTokenProvider(sastoken_generator) + + with pytest.raises(SasTokenError): + await provider.start() + + # 
NOTE: The contents of this coroutine are tested in a separate test suite below. + # See TestSasTokenProviderKeepTokenFresh for more. + @pytest.mark.it("Begins running the ._keep_token_fresh() coroutine method, storing the task") + async def test_keep_token_fresh_running(self, sastoken_generator): + provider = SasTokenProvider(sastoken_generator) + assert provider._keep_token_fresh_bg_task is None + + await provider.start() + + assert isinstance(provider._keep_token_fresh_bg_task, asyncio.Task) + assert not provider._keep_token_fresh_bg_task.done() + if sys.version_info >= (3, 8): + # NOTE: There isn't a way to validate the contents of a task until 3.8 + # as far as I can tell. + task_coro = provider._keep_token_fresh_bg_task.get_coro() + assert task_coro.__qualname__ == "SasTokenProvider._keep_token_fresh" + + # Cleanup + await provider.stop() + + @pytest.mark.it("Does nothing if already started") + async def test_already_started(self, sastoken_generator): + provider = SasTokenProvider(sastoken_generator) + + # Start + await provider.start() + + # Expected state + assert isinstance(provider._keep_token_fresh_bg_task, asyncio.Task) + assert not provider._keep_token_fresh_bg_task.done() + current_keep_token_fresh_bg_task = provider._keep_token_fresh_bg_task + assert sastoken_generator.generate_sastoken.await_count == 1 + + # Start again + await provider.start() + + # No changes + assert provider._keep_token_fresh_bg_task is current_keep_token_fresh_bg_task + assert not provider._keep_token_fresh_bg_task.done() + assert sastoken_generator.generate_sastoken.await_count == 1 + + +@pytest.mark.describe("SasTokenProvider - .stop()") +class TestSasTokenProviderShutdown: + @pytest.mark.it("Cancels the stored ._keep_token_fresh() task and removes it, if it exists") + async def test_cancels_keep_token_fresh(self, sastoken_provider): + t = sastoken_provider._keep_token_fresh_bg_task + assert isinstance(t, asyncio.Task) + assert not t.done() + + await sastoken_provider.stop() + 
+ assert t.done() + assert t.cancelled() + assert sastoken_provider._keep_token_fresh_bg_task is None + + @pytest.mark.it("Sets the current token back to None") + async def test_current_token(self, sastoken_provider): + assert sastoken_provider._current_token is not None + + await sastoken_provider.stop() + + assert sastoken_provider._current_token is None + + @pytest.mark.it("Does nothing if already stopped") + async def test_already_stopped(self, sastoken_provider): + # Currently running + t = sastoken_provider._keep_token_fresh_bg_task + assert not t.done() + + # Stop + await sastoken_provider.stop() + + # Expected state + assert t.done() + assert t.cancelled() + assert sastoken_provider._keep_token_fresh_bg_task is None + assert sastoken_provider._current_token is None + + # Stop again + await sastoken_provider.stop() + + # No changes + assert sastoken_provider._keep_token_fresh_bg_task is None + assert sastoken_provider._current_token is None + + +@pytest.mark.describe("SasTokenProvider - .get_current_sastoken()") +class TestSasTokenGetCurrentSasToken: + @pytest.mark.it("Returns the current SasToken object, if running") + def test_returns_current_token(self, sastoken_provider): + assert sastoken_provider._keep_token_fresh_bg_task is not None + current_token = sastoken_provider.get_current_sastoken() + assert current_token is sastoken_provider._current_token + new_token_str = TOKEN_FORMAT.format( + resource=urllib.parse.quote(FAKE_URI, safe=""), + signature=urllib.parse.quote(FAKE_SIGNED_DATA, safe=""), + expiry=int(time.time()) + 3600, + ) + new_token = SasToken(new_token_str) + sastoken_provider._current_token = new_token + assert sastoken_provider.get_current_sastoken() is new_token + + @pytest.mark.it("Raises RuntimeError if not running (i.e. 
not started)") + async def test_not_running(self, sastoken_generator): + provider = SasTokenProvider(sastoken_generator) + assert provider._keep_token_fresh_bg_task is None + with pytest.raises(RuntimeError): + provider.get_current_sastoken() + + +@pytest.mark.describe("SasTokenProvider - .wait_for_new_sastoken()") +class TestSasTokenWaitForNewSasToken: + @pytest.mark.it( + "Returns the current SasToken object once a notified of a new token being available" + ) + async def test_returns_new_current_token(self, sastoken_provider): + token_str_1 = TOKEN_FORMAT.format( + resource=urllib.parse.quote(FAKE_URI, safe=""), + signature=urllib.parse.quote(FAKE_SIGNED_DATA, safe=""), + expiry=int(time.time()) + 3600, + ) + token1 = SasToken(token_str_1) + token_str_2 = TOKEN_FORMAT.format( + resource=urllib.parse.quote(FAKE_URI, safe=""), + signature=urllib.parse.quote(FAKE_SIGNED_DATA2, safe=""), + expiry=int(time.time()) + 4500, + ) + token2 = SasToken(token_str_2) + + sastoken_provider._current_token = token1 + assert sastoken_provider.get_current_sastoken() is token1 + + # Waiting for new token, but one is not yet available + task = asyncio.create_task(sastoken_provider.wait_for_new_sastoken()) + await asyncio.sleep(0.1) + assert not task.done() + + # Update the token, but without notification, the waiting task still does not return + sastoken_provider._current_token = token2 + await asyncio.sleep(0.1) + assert not task.done() + + # Notify that a new token is available, and now the task will return + async with sastoken_provider._new_sastoken_available: + sastoken_provider._new_sastoken_available.notify_all() + returned_token = await task + + # The task returned the new token + assert returned_token is token2 + assert returned_token is not token1 + assert returned_token is sastoken_provider.get_current_sastoken() + + +# NOTE: This test suite assumes the correct implementation of ._wait_until() for critical +# requirements. 
Find it tested in a separate suite below (TestWaitUntil) +@pytest.mark.describe("SasTokenProvider - BG TASK: ._keep_token_fresh") +class TestSasTokenProviderKeepTokenFresh: + @pytest.fixture(autouse=True) + def spy_time(self, mocker): + """Spy on the time module so that we can find out last time that was returned""" + spy_time = mocker.spy(time, "time") + return spy_time + + # NOTE: This is an autouse fixture to ensure that it gets called first, since we want to make sure + # this mock is running when the SasTokenProvider is created. + @pytest.fixture(autouse=True) + def mock_wait_until(self, mocker): + """Mock out the wait_until function so these tests aren't dependent on real time passing""" + mock_wait_until = mocker.patch.object(st, "_wait_until") + mock_wait_until._allow_proceed = asyncio.Event() + + # Fake implementation that will wait for an explicit trigger to proceed, rather than the + # passage of time + async def fake_wait_until(when): + await mock_wait_until._allow_proceed.wait() + + mock_wait_until.side_effect = fake_wait_until + + # Define a mechanism that will allow an explicit trigger to let the mocked coroutine return + def proceed(): + mock_wait_until._allow_proceed.set() + mock_wait_until._allow_proceed = asyncio.Event() + + mock_wait_until.proceed = proceed + + return mock_wait_until + + @pytest.mark.it( + "Waits until the configured update margin number of seconds before current SasToken expiry to generate a new SasToken" + ) + async def test_wait_to_generate(self, mocker, mock_wait_until, sastoken_provider): + original_token = sastoken_provider.get_current_sastoken() + assert sastoken_provider._generator.generate_sastoken.await_count == 0 + await asyncio.sleep(0.1) + # We are waiting the expected amount of time + expected_update_time = original_token.expiry_time - sastoken_provider._token_update_margin + assert mock_wait_until.await_count == 1 + assert mock_wait_until.await_args == mocker.call(expected_update_time) + # Allow the waiting to 
end, and a new token to be generated + mock_wait_until.proceed() + await asyncio.sleep(0.1) + assert sastoken_provider._generator.generate_sastoken.await_count == 1 + assert sastoken_provider._generator.generate_sastoken.await_args == mocker.call() + + @pytest.mark.it( + "Sets the newly generated SasToken as the new current SasToken and sends notification of its availability" + ) + async def test_replace_token_and_notify(self, mocker, sastoken_provider, mock_wait_until): + notification_spy = mocker.spy(sastoken_provider._new_sastoken_available, "notify_all") + # We have the original token, as we have not yet generated a new one + original_token = sastoken_provider.get_current_sastoken() + assert sastoken_provider._generator.generate_sastoken.await_count == 0 + assert notification_spy.call_count == 0 + # Allow waiting to proceed, and a new token to be generated + mock_wait_until.proceed() + await asyncio.sleep(0.1) + # A new token has now been generated + assert sastoken_provider._generator.generate_sastoken.await_count == 1 + # The current token is now the token that was just generated + current_token = sastoken_provider.get_current_sastoken() + assert current_token is sastoken_provider._generator.generate_sastoken.spy_return + # This token is not the same as the original token + assert current_token is not original_token + # A notification was sent about the new token + assert notification_spy.call_count == 1 + + @pytest.mark.it( + "Waits until the configured update margin number of seconds before the NEW current SasToken expiry, after each time a new SasToken is generated, before once again generating a new SasToken" + ) + async def test_wait_to_generate_again_and_again( + self, mocker, mock_wait_until, sastoken_provider + ): + # Current token is the original, we have not yet generated a new one + original_token = sastoken_provider.get_current_sastoken() + assert sastoken_provider._generator.generate_sastoken.await_count == 0 + await asyncio.sleep(0.1) + # We are 
waiting based on the original token's expiry time + expected_update_time = original_token.expiry_time - sastoken_provider._token_update_margin + assert mock_wait_until.await_count == 1 + assert mock_wait_until.await_args == mocker.call(expected_update_time) + # Allow the waiting to end, and a new token to be generated + mock_wait_until.proceed() + await asyncio.sleep(0.1) + assert sastoken_provider._generator.generate_sastoken.await_count == 1 + assert sastoken_provider._generator.generate_sastoken.await_args == mocker.call() + # New token is the one that was just generated + new_token = sastoken_provider.get_current_sastoken() + assert new_token is sastoken_provider._generator.generate_sastoken.spy_return + assert new_token is not original_token + # We are once again waiting, this time based on the new token's expiry time + expected_update_time = new_token.expiry_time - sastoken_provider._token_update_margin + assert mock_wait_until.await_count == 2 + assert mock_wait_until.await_args == mocker.call(expected_update_time) + # Allow the waiting to end and another new token to be generated + mock_wait_until.proceed() + await asyncio.sleep(0.1) + assert sastoken_provider._generator.generate_sastoken.await_count == 2 + assert sastoken_provider._generator.generate_sastoken.await_args == mocker.call() + # Newest token is the one that was just generated + newest_token = sastoken_provider.get_current_sastoken() + assert newest_token is sastoken_provider._generator.generate_sastoken.spy_return + assert newest_token is not original_token + assert newest_token is not new_token + # We are once again waiting, this time based on the newest token's expiry time + expected_update_time = newest_token.expiry_time - sastoken_provider._token_update_margin + assert mock_wait_until.await_count == 3 + assert mock_wait_until.await_args == mocker.call(expected_update_time) + # And so on and so forth to infinity... 
+ + @pytest.mark.it( + "Sets the newly generated SasToken as the new current SasToken and sends notification of its availability each time a new token is generated" + ) + async def test_replace_token_and_notify_each_time( + self, mocker, sastoken_provider, mock_wait_until + ): + notification_spy = mocker.spy(sastoken_provider._new_sastoken_available, "notify_all") + # We have the original token, as we have not yet generated a new one + original_token = sastoken_provider.get_current_sastoken() + assert sastoken_provider._generator.generate_sastoken.await_count == 0 + assert notification_spy.call_count == 0 + # Allow waiting to proceed, and a new token to be generated + mock_wait_until.proceed() + await asyncio.sleep(0.1) + # A new token has now been generated + assert sastoken_provider._generator.generate_sastoken.await_count == 1 + # The current token is now the token that was just generated + second_token = sastoken_provider.get_current_sastoken() + assert second_token is sastoken_provider._generator.generate_sastoken.spy_return + # This token is not the same as the original token + assert second_token is not original_token + # A notification was sent about the new token + assert notification_spy.call_count == 1 + # Allow waiting to proceed and another new token to be generated + mock_wait_until.proceed() + await asyncio.sleep(0.1) + # Another new token has now been generated + assert sastoken_provider._generator.generate_sastoken.await_count == 2 + # The current token is now the token that was just generated + third_token = sastoken_provider.get_current_sastoken() + assert third_token is sastoken_provider._generator.generate_sastoken.spy_return + # This token is not the same as any previous token + assert third_token is not original_token + assert third_token is not second_token + # A notification was sent about the new token + assert notification_spy.call_count == 2 + # And so on and so forth to infinity... 
+ + @pytest.mark.it("Tries to generate again in 10 seconds if SasToken generation fails") + @pytest.mark.parametrize( + "exception", + [ + pytest.param(SasTokenError("Some error in SAS"), id="SasTokenError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_generation_failure( + self, mocker, sastoken_provider, exception, mock_wait_until, spy_time + ): + # Set token generation to raise exception + sastoken_provider._generator.generate_sastoken.side_effect = exception + # Token generation has not yet happened + assert sastoken_provider._generator.generate_sastoken.await_count == 0 + # Allow waiting to proceed, and a new token to be generated + assert mock_wait_until.await_count == 1 + mock_wait_until.proceed() + await asyncio.sleep(0.1) + # Waits 10 seconds past the current time + expected_generate_time = spy_time.spy_return + 10 + assert mock_wait_until.await_count == 2 + assert mock_wait_until.await_args == mocker.call(expected_generate_time) + + +# NOTE: We don't normally test convention-private helpers directly, but in this case, the +# complexity is high enough, and the function is critical enough, that it makes more sense +# to isolate rather than attempting to indirectly test. 
+@pytest.mark.describe("._wait_until()") +class TestWaitUntil: + @pytest.mark.it( + "Repeatedly does 1 second asyncio sleeps until the current time is greater than the provided 'when' parameter" + ) + @pytest.mark.parametrize( + "time_from_now", + [ + pytest.param(5, id="5 seconds from now"), + pytest.param(60, id="1 minute from now"), + pytest.param(3600, id="1 hour from now"), + ], + ) + async def test_sleep(self, mocker, time_from_now): + # Mock out the sleep coroutine so that we aren't waiting around forever on this test + mock_sleep = mocker.patch.object(asyncio, "sleep") + + # mock out time + def fake_time(): + """Fake time implementation that will return a time float that is 1 larger + than the previous time it was called""" + fake_time_return = FAKE_CURRENT_TIME + while True: + yield fake_time_return + fake_time_return += 1 + + fake_time_gen = fake_time() + mock_time = mocker.patch.object(time, "time", side_effect=fake_time_gen) + + desired_time = FAKE_CURRENT_TIME + time_from_now + + await st._wait_until(desired_time) + + assert mock_sleep.await_count == time_from_now + for call in mock_sleep.await_args_list: + assert call == mocker.call(1) + assert mock_time.call_count == time_from_now + 1 + for call in mock_time.call_args_list: + assert call == mocker.call() diff --git a/tests/unit/common/auth/test_signing_mechanism.py b/tests/unit/test_signing_mechanism.py similarity index 88% rename from tests/unit/common/auth/test_signing_mechanism.py rename to tests/unit/test_signing_mechanism.py index f0f8a6b76..c63fde562 100644 --- a/tests/unit/common/auth/test_signing_mechanism.py +++ b/tests/unit/test_signing_mechanism.py @@ -5,13 +5,10 @@ # -------------------------------------------------------------------------- import pytest -import logging import hmac import hashlib import base64 -from azure.iot.device.common.auth import SymmetricKeySigningMechanism - -logging.basicConfig(level=logging.DEBUG) +from azure.iot.device.signing_mechanism import 
SymmetricKeySigningMechanism @pytest.mark.describe("SymmetricKeySigningMechanism - Instantiation") @@ -77,13 +74,13 @@ def signing_mechanism(self): @pytest.mark.it( "Generates an HMAC message digest from the signing key and provided data string, using the HMAC-SHA256 algorithm" ) - def test_hmac(self, mocker, signing_mechanism): + async def test_hmac(self, mocker, signing_mechanism): hmac_mock = mocker.patch.object(hmac, "HMAC") hmac_digest_mock = hmac_mock.return_value.digest hmac_digest_mock.return_value = b"\xd2\x06\xf7\x12\xf1\xe9\x95$\x90\xfd\x12\x9a\xb1\xbe\xb4\xf8\xf3\xc4\x1ap\x8a\xab'\x8a.D\xfb\x84\x96\xca\xf3z" data_string = "sign this message" - signing_mechanism.sign(data_string) + await signing_mechanism.sign(data_string) assert hmac_mock.call_count == 1 assert hmac_mock.call_args == mocker.call( @@ -96,13 +93,13 @@ def test_hmac(self, mocker, signing_mechanism): @pytest.mark.it( "Returns the base64 encoded HMAC message digest (converted to string) as the signed data" ) - def test_b64encode(self, mocker, signing_mechanism): + async def test_b64encode(self, mocker, signing_mechanism): hmac_mock = mocker.patch.object(hmac, "HMAC") hmac_digest_mock = hmac_mock.return_value.digest hmac_digest_mock.return_value = b"\xd2\x06\xf7\x12\xf1\xe9\x95$\x90\xfd\x12\x9a\xb1\xbe\xb4\xf8\xf3\xc4\x1ap\x8a\xab'\x8a.D\xfb\x84\x96\xca\xf3z" data_string = "sign this message" - signature = signing_mechanism.sign(data_string) + signature = await signing_mechanism.sign(data_string) assert signature == base64.b64encode(hmac_digest_mock.return_value).decode("utf-8") @@ -118,11 +115,11 @@ def test_b64encode(self, mocker, signing_mechanism): ), ], ) - def test_supported_types(self, signing_mechanism, data_string, expected_signature): - assert signing_mechanism.sign(data_string) == expected_signature + async def test_supported_types(self, signing_mechanism, data_string, expected_signature): + assert await signing_mechanism.sign(data_string) == expected_signature 
@pytest.mark.it("Raises a ValueError if unable to sign the provided data string") @pytest.mark.parametrize("data_string", [pytest.param(123, id="Integer input")]) - def test_bad_input(self, signing_mechanism, data_string): + async def test_bad_input(self, signing_mechanism, data_string): with pytest.raises(ValueError): - signing_mechanism.sign(data_string) + await signing_mechanism.sign(data_string) diff --git a/tests/unit/test_user_agent.py b/tests/unit/test_user_agent.py index e005a74d9..770700bd9 100644 --- a/tests/unit/test_user_agent.py +++ b/tests/unit/test_user_agent.py @@ -4,8 +4,8 @@ # license information. # -------------------------------------------------------------------------- import pytest -from azure.iot.device import user_agent import platform +from azure.iot.device import user_agent from azure.iot.device.constant import VERSION, IOTHUB_IDENTIFIER, PROVISIONING_IDENTIFIER @@ -15,7 +15,7 @@ @pytest.mark.describe(".get_iothub_user_agent()") -class TestGetIothubUserAgent(object): +class TestGetIothubUserAgent: @pytest.mark.it( "Returns a user agent string formatted for IoTHub, containing python version, operating system and architecture of the system" ) @@ -40,7 +40,7 @@ def test_get_iothub_user_agent(self): @pytest.mark.describe(".get_provisioning_user_agent()") -class TestGetProvisioningUserAgent(object): +class TestGetProvisioningUserAgent: @pytest.mark.it( "Returns a user agent string formatted for the Provisioning Service, containing python version, operating system and architecture of the system" ) diff --git a/versions.json b/versions.json new file mode 100644 index 000000000..a15c37f4b --- /dev/null +++ b/versions.json @@ -0,0 +1,9 @@ +[ + { + "version" : "2.12.0" + }, + { + "version" : "3.0.0b2", + "eol" : "2023-07-15" + } +] diff --git a/vsts/build.yaml b/vsts/build.yaml index 48f263e84..f5f06d295 100644 --- a/vsts/build.yaml +++ b/vsts/build.yaml @@ -29,8 +29,6 @@ jobs: vmImage: 'Ubuntu 20.04' strategy: matrix: - Python36: - python.version: 
'3.6' Python37: python.version: '3.7' Python38: @@ -39,6 +37,8 @@ jobs: python.version: '3.9' Python310: python.version: '3.10' + Python311: + python.version: '3.11' steps: - task: UsePythonVersion@0 displayName: 'Use Python $(python.version)' diff --git a/vsts/dps-e2e.yaml b/vsts/dps-e2e.yaml index 1ec9e178c..547bc5d7f 100644 --- a/vsts/dps-e2e.yaml +++ b/vsts/dps-e2e.yaml @@ -31,19 +31,19 @@ jobs: displayName: 'Run Specified E2E Test with env variables' env: - IOTHUB_CONNECTION_STRING: $(PYTHONSDK2-LINUX-IOTHUB-CONNECTION-STRING) - IOTHUB_EVENTHUB_CONNECTION_STRING: $(PYTHONSDK2-LINUX-IOTHUB-EVENTHUB-CONNECTION-STRING) - IOTHUB_CA_ROOT_CERT: $(PYTHONSDK2-LINUX-IOTHUB-CA-ROOT-CERT) - IOTHUB_CA_ROOT_CERT_KEY: $(PYTHONSDK2-LINUX-IOTHUB-CA-ROOT-CERT-KEY) - STORAGE_CONNECTION_STRING: $(PYTHONSDK2-LINUX-STORAGE-CONNECTION-STRING) - - PROVISIONING_DEVICE_ENDPOINT: $(PYTHONSDK2-LINUX-DPS-DEVICE-ENDPOINT) - PROVISIONING_SERVICE_CONNECTION_STRING: $(PYTHONSDK2-LINUX-DPS-CONNECTION-STRING) - PROVISIONING_DEVICE_IDSCOPE: $(PYTHONSDK2-LINUX-DPS-ID-SCOPE) - - PROVISIONING_ROOT_CERT: $(PYTHONSDK2-LINUX-IOT-PROVISIONING-ROOT-CERT) - PROVISIONING_ROOT_CERT_KEY: $(PYTHONSDK2-LINUX-IOT-PROVISIONING-ROOT-CERT-KEY) - PROVISIONING_ROOT_PASSWORD: $(PYTHONSDK2-LINUX-ROOT-CERT-PASSWORD) + IOTHUB_CONNECTION_STRING: $(PYTHONOCT22-MAC-IOTHUB-CONNECTION-STRING) + IOTHUB_EVENTHUB_CONNECTION_STRING: $(PYTHONOCT22-MAC-IOTHUB-EVENTHUB-CONNECTION-STRING) + IOTHUB_CA_ROOT_CERT: $(PYTHONOCT22-MAC-IOTHUB-CA-ROOT-CERT) + IOTHUB_CA_ROOT_CERT_KEY: $(PYTHONOCT22-MAC-IOTHUB-CA-ROOT-CERT-KEY) + STORAGE_CONNECTION_STRING: $(PYTHONOCT22-MAC-STORAGE-CONNECTION-STRING) + + PROVISIONING_DEVICE_ENDPOINT: $(PYTHONOCT22-MAC-DPS-DEVICE-ENDPOINT) + PROVISIONING_SERVICE_CONNECTION_STRING: $(PYTHONOCT22-MAC-DPS-CONNECTION-STRING) + PROVISIONING_DEVICE_IDSCOPE: $(PYTHONOCT22-MAC-DPS-ID-SCOPE) + + PROVISIONING_ROOT_CERT: $(PYTHONOCT22-MAC-IOT-PROVISIONING-ROOT-CERT) + PROVISIONING_ROOT_CERT_KEY: 
$(PYTHONOCT22-MAC-IOT-PROVISIONING-ROOT-CERT-KEY) + PROVISIONING_ROOT_PASSWORD: $(PYTHONOCT22-MAC-ROOT-CERT-PASSWORD) PYTHONUNBUFFERED: True - task: PublishTestResults@2 diff --git a/vsts/python-canary.yaml b/vsts/python-canary.yaml index 8378006e8..2e7f05bdf 100644 --- a/vsts/python-canary.yaml +++ b/vsts/python-canary.yaml @@ -16,25 +16,25 @@ jobs: transport: 'mqttws' imageName: 'windows-latest' consumerGroup: 'cg2' - py36_linux_mqtt: - pv: '3.6' - transport: 'mqtt' - imageName: 'Ubuntu 20.04' - consumerGroup: 'cg3' py37_linux_mqttws: pv: '3.7' transport: 'mqttws' imageName: 'Ubuntu 20.04' - consumerGroup: 'cg4' + consumerGroup: 'cg3' py38_linux_mqtt: pv: '3.8' transport: 'mqtt' imageName: 'Ubuntu 20.04' - consumerGroup: 'cg5' + consumerGroup: 'cg4' py310_linux_mqtt: pv: '3.10' transport: 'mqtt' imageName: 'Ubuntu 20.04' + consumerGroup: 'cg5' + py311_linux_mqtt: + pv: '3.11' + transport: 'mqtt' + imageName: 'Ubuntu 20.04' consumerGroup: 'cg6' pool: diff --git a/vsts/python-e2e.yaml b/vsts/python-e2e.yaml index 4af63f148..85e56d32e 100644 --- a/vsts/python-e2e.yaml +++ b/vsts/python-e2e.yaml @@ -40,7 +40,7 @@ jobs: strategy: matrix: py310_mqtt: { pv: '3.10', transport: 'mqtt', consumer_group: 'e2e-consumer-group-3' } - py36_mqtt_ws: { pv: '3.6', transport: 'mqttws', consumer_group: 'e2e-consumer-group-4' } + py37_mqtt_ws: { pv: '3.7', transport: 'mqttws', consumer_group: 'e2e-consumer-group-4' } steps: - task: UsePythonVersion@0 diff --git a/vsts/python-edge-e2e.yaml b/vsts/python-edge-e2e.yaml new file mode 100644 index 000000000..1d16bb035 --- /dev/null +++ b/vsts/python-edge-e2e.yaml @@ -0,0 +1,109 @@ +name: $(BuildID)_$(BuildDefinitionName)_$(SourceBranchName) + +variables: + hubName: 'iotsdk-python-horton-hub' + runCount: $[counter(0,100)] + containerRegistryShortName: iotsdke2e + containerRegistry: iotsdke2e.azurecr.io + scriptDirectory: $(Build.SourcesDirectory)/scripts/edge_setup + +jobs: +- job: 'Test' + + strategy: + maxParallel: 4 + matrix: + 
linux_edge_py310_mqtt: + languageVersion: 'py310' + transport: 'mqtt' + imageName: 'Ubuntu 20.04' + consumerGroup: 'cg6' + + pool: + vmImage: $(imageName) + + variables: + deviceId: $(languageVersion)-$(runCount) + dockerImageTag: $(runCount) + testModImage: "python-$(languageVersion):$(runCount)" + echoModImage: "echomod:latest" + + steps: + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.10' + architecture: 'x64' + + - task: AzureCLI@2 + displayName: Install Prerequisites + inputs: + azureSubscription: 'ServiceConnectionDemo' + scriptType: bash + scriptLocation: scriptPath + scriptPath: $(scriptDirectory)/install/install-prereqs.sh + + + - task: AzureCLI@2 + displayName: Create and push docker images + inputs: + azureSubscription: 'ServiceConnectionDemo' + scriptType: bash + scriptLocation: inlineScript + inlineScript: | + az acr login -n $(containerRegistry) + cd $(scriptDirectory)/docker-build/ + ./build-echomod-container.sh $(containerRegistry)/$(echoModImage) + ./build-test-containers.sh Dockerfile.$(languageVersion) $(containerRegistry)/$(testModImage) + + - task: AzureCLI@2 + displayName: Create and configure edge device ${{ variables.deviceId }} + inputs: + azureSubscription: 'ServiceConnectionDemo' + scriptType: bash + scriptLocation: inlineScript + inlineScript: | + cd $(scriptDirectory)/deploy/ + export IOTHUB_E2E_REPO_ADDRESS=$(IOTHUB-E2E-REPO-ADDRESS) + export IOTHUB_E2E_REPO_USER=$(IOTHUB-E2E-REPO-USER) + export IOTHUB_E2E_REPO_PASSWORD=$(IOTHUB-E2E-REPO-PASSWORD) + ./create-edge-device.sh $(hubName) $(deviceId) + ./deploy-edge-modules.sh $(hubName) $(deviceId) $(containerRegistry)/$(testModImage) $(containerRegistry)/$(echoModImage) + + - task: Bash@3 + displayName: "Wait for container to start" + timeoutInMinutes: 15 + inputs: + targetType: script + script: | + cd $(ScriptDirectory) + ./wait-for-container.sh edgeAgent + ./wait-for-container.sh edgeHub + ./wait-for-container.sh testMod + ./wait-for-container.sh echoMod + docker ps + + - task: 
AzureCLI@2 + displayName: Remove Edge Device ${{ variables.deviceId }} + inputs: + azureSubscription: 'ServiceConnectionDemo' + scriptType: bash + scriptLocation: inlineScript + inlineScript: | + az iot hub device-identity delete -n $(hubName) --device-id $(deviceId) + condition: always() + + - task: AzureCLI@2 + displayName: Remove ${{ variables.testModImage }} container + inputs: + azureSubscription: 'ServiceConnectionDemo' + scriptType: bash + scriptLocation: inlineScript + inlineScript: | + az acr login -n $(containerRegistryShortName) + az acr repository delete -n $(containerRegistryShortName) -t $(testModImage) --yes + condition: always() + + - task: PublishTestResults@2 + displayName: 'Publish Test Results' + condition: always() + diff --git a/vsts/python-nightly.yaml b/vsts/python-nightly.yaml index 86f7467ad..eb820b261 100644 --- a/vsts/python-nightly.yaml +++ b/vsts/python-nightly.yaml @@ -17,30 +17,30 @@ jobs: imageName: 'windows-latest' consumerGroup: 'cg2' - py36_linux_mqtt: - pv: '3.6' - transport: 'mqtt' - imageName: 'Ubuntu 20.04' - consumerGroup: 'cg3' py37_linux_mqttws: pv: '3.7' transport: 'mqttws' imageName: 'Ubuntu 20.04' - consumerGroup: 'cg4' + consumerGroup: 'cg3' py38_linux_mqtt: pv: '3.8' transport: 'mqtt' imageName: 'Ubuntu 20.04' - consumerGroup: 'cg5' + consumerGroup: 'cg4' py39_linux_mqttws: pv: '3.9' transport: 'mqttws' imageName: 'Ubuntu 20.04' - consumerGroup: 'cg6' + consumerGroup: 'cg5' py310_linux_mqtt: pv: '3.10' transport: 'mqtt' imageName: 'Ubuntu 20.04' + consumerGroup: 'cg6' + py311_linux_mqtt: + pv: '3.11' + transport: 'mqtt' + imageName: 'Ubuntu 20.04' consumerGroup: 'cg7' pool: