diff --git a/sdk/storage/azure-storage-file/azure/storage/file/_shared/base_client.py b/sdk/storage/azure-storage-file/azure/storage/file/_shared/base_client.py
index 1b526d505da2..5d338689908d 100644
--- a/sdk/storage/azure-storage-file/azure/storage/file/_shared/base_client.py
+++ b/sdk/storage/azure-storage-file/azure/storage/file/_shared/base_client.py
@@ -10,9 +10,10 @@
 )
 import logging
 try:
-    from urllib.parse import parse_qs
+    from urllib.parse import parse_qs, quote
 except ImportError:
     from urlparse import parse_qs  # type: ignore
+    from urllib2 import quote  # type: ignore
 
 import six
 
@@ -156,7 +157,7 @@ def _create_pipeline(self, credential, **kwargs):
         config = kwargs.get('_configuration') or create_configuration(**kwargs)
         if kwargs.get('_pipeline'):
             return config, kwargs['_pipeline']
-        config.transport = kwargs.get('transport')  # type: HttpTransport
+        config.transport = kwargs.get('transport')  # type: ignore
         if not config.transport:
             config.transport = RequestsTransport(config)
         policies = [
@@ -276,7 +277,7 @@ def create_configuration(**kwargs):
 def parse_query(query_str):
     sas_values = QueryStringConstants.to_list()
     parsed_query = {k: v[0] for k, v in parse_qs(query_str).items()}
-    sas_params = ["{}={}".format(k, v) for k, v in parsed_query.items() if k in sas_values]
+    sas_params = ["{}={}".format(k, quote(v)) for k, v in parsed_query.items() if k in sas_values]
     sas_token = None
     if sas_params:
         sas_token = '&'.join(sas_params)
diff --git a/sdk/storage/azure-storage-file/azure/storage/file/_shared/base_client_async.py b/sdk/storage/azure-storage-file/azure/storage/file/_shared/base_client_async.py
index 2e15dd9e8813..19d314f87892 100644
--- a/sdk/storage/azure-storage-file/azure/storage/file/_shared/base_client_async.py
+++ b/sdk/storage/azure-storage-file/azure/storage/file/_shared/base_client_async.py
@@ -30,7 +30,9 @@
     QueueMessagePolicy)
 from .policies_async import AsyncStorageResponseHook
 
-
+if TYPE_CHECKING:
+    from azure.core.pipeline import Pipeline
+    from azure.core import Configuration
 _LOGGER = logging.getLogger(__name__)
 
@@ -60,11 +62,11 @@ def _create_pipeline(self, credential, **kwargs):
             raise TypeError("Unsupported credential: {}".format(credential))
 
         if 'connection_timeout' not in kwargs:
-            kwargs['connection_timeout'] = DEFAULT_SOCKET_TIMEOUT[0]
+            kwargs['connection_timeout'] = DEFAULT_SOCKET_TIMEOUT[0]  # type: ignore
         config = kwargs.get('_configuration') or create_configuration(**kwargs)
         if kwargs.get('_pipeline'):
             return config, kwargs['_pipeline']
-        config.transport = kwargs.get('transport')  # type: HttpTransport
+        config.transport = kwargs.get('transport')  # type: ignore
         if not config.transport:
             config.transport = AsyncTransport(config)
         policies = [
@@ -76,7 +78,7 @@ def _create_pipeline(self, credential, **kwargs):
             credential_policy,
             ContentDecodePolicy(),
             AsyncRedirectPolicy(**kwargs),
-            StorageHosts(hosts=self._hosts, **kwargs),
+            StorageHosts(hosts=self._hosts, **kwargs),  # type: ignore
             config.retry_policy,
             config.logging_policy,
             AsyncStorageResponseHook(**kwargs),
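The parse_query change in both base_client.py files exists because parse_qs percent-decodes the SAS parameter values; re-joining them verbatim would leave reserved characters such as '+' and '=' unescaped inside the signature. A minimal standalone sketch of that round trip (the query string and signature value here are made up):

    # Python 3 names; the diff aliases the same functions for Python 2.
    from urllib.parse import parse_qs, quote

    query = "sv=2018-03-28&sig=abc%2Bdef%3D"  # hypothetical SAS token
    parsed = {k: v[0] for k, v in parse_qs(query).items()}
    assert parsed["sig"] == "abc+def="  # parse_qs has percent-decoded the value
    # Re-quoting on the way out restores a transmittable token.
    rebuilt = "&".join("{}={}".format(k, quote(v)) for k, v in parsed.items())
    assert "sig=abc%2Bdef%3D" in rebuilt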
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/__init__.py b/sdk/storage/azure-storage-queue/azure/storage/queue/__init__.py
index e027028eee3e..8cd24874fff0 100644
--- a/sdk/storage/azure-storage-queue/azure/storage/queue/__init__.py
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/__init__.py
@@ -14,7 +14,7 @@
     AccountPermissions,
     StorageErrorCode
 )
-from ._queue_utils import (
+from ._message_encoding import (
     TextBase64EncodePolicy,
     TextBase64DecodePolicy,
     BinaryBase64EncodePolicy,
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py
new file mode 100644
index 000000000000..3137faa0aaa4
--- /dev/null
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py
@@ -0,0 +1,41 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+# pylint: disable=unused-argument
+
+from azure.core.exceptions import ResourceExistsError
+
+from ._shared.models import StorageErrorCode
+from .models import QueueProperties
+
+
+def deserialize_metadata(response, obj, headers):
+    raw_metadata = {k: v for k, v in response.headers.items() if k.startswith("x-ms-meta-")}
+    return {k[10:]: v for k, v in raw_metadata.items()}
+
+
+def deserialize_queue_properties(response, obj, headers):
+    metadata = deserialize_metadata(response, obj, headers)
+    queue_properties = QueueProperties(
+        metadata=metadata,
+        **headers
+    )
+    return queue_properties
+
+
+def deserialize_queue_creation(response, obj, headers):
+    if response.status_code == 204:
+        error_code = StorageErrorCode.queue_already_exists
+        error = ResourceExistsError(
+            message="Queue already exists\nRequestId:{}\nTime:{}\nErrorCode:{}".format(
+                headers['x-ms-request-id'],
+                headers['Date'],
+                error_code
+            ),
+            response=response)
+        error.error_code = error_code
+        error.additional_info = {}
+        raise error
+    return headers
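deserialize_metadata in the new _deserialize.py keeps only the x-ms-meta-* headers and strips the prefix with k[10:], which works because len("x-ms-meta-") is exactly 10. A self-contained illustration with hypothetical headers:

    headers = {
        "x-ms-meta-category": "logs",   # hypothetical service response headers
        "x-ms-request-id": "abc-123",
    }
    raw_metadata = {k: v for k, v in headers.items() if k.startswith("x-ms-meta-")}
    metadata = {k[10:]: v for k, v in raw_metadata.items()}
    assert metadata == {"category": "logs"}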
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_utils.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py
similarity index 83%
rename from sdk/storage/azure-storage-queue/azure/storage/queue/_queue_utils.py
rename to sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py
index 74d433e0de56..3e7e64fcd4dd 100644
--- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_utils.py
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py
@@ -10,41 +10,9 @@
 from xml.sax.saxutils import unescape as xml_unescape
 import six
 
-from azure.core.exceptions import ResourceExistsError, DecodeError
+from azure.core.exceptions import DecodeError
 
-from ._shared.models import StorageErrorCode
 from ._shared.encryption import decrypt_queue_message, encrypt_queue_message
-from .models import QueueProperties
-
-
-def deserialize_metadata(response, obj, headers):
-    raw_metadata = {k: v for k, v in response.headers.items() if k.startswith("x-ms-meta-")}
-    return {k[10:]: v for k, v in raw_metadata.items()}
-
-
-def deserialize_queue_properties(response, obj, headers):
-    metadata = deserialize_metadata(response, obj, headers)
-    queue_properties = QueueProperties(
-        metadata=metadata,
-        **headers
-    )
-    return queue_properties
-
-
-def deserialize_queue_creation(response, obj, headers):
-    if response.status_code == 204:
-        error_code = StorageErrorCode.queue_already_exists
-        error = ResourceExistsError(
-            message="Queue already exists\nRequestId:{}\nTime:{}\nErrorCode:{}".format(
-                headers['x-ms-request-id'],
-                headers['Date'],
-                error_code
-            ),
-            response=response)
-        error.error_code = error_code
-        error.additional_info = {}
-        raise error
-    return headers
 
 
 class MessageEncodePolicy(object):
@@ -104,7 +72,7 @@ def decode(self, content, response):
 
 class TextBase64EncodePolicy(MessageEncodePolicy):
     """Base 64 message encoding policy for text messages.
-    
+
     Encodes text (unicode) messages to base 64. If the input content
     is not text, a TypeError will be raised. Input text must support UTF-8.
     """
@@ -117,7 +85,7 @@ def encode(self, content):
 
 class TextBase64DecodePolicy(MessageDecodePolicy):
     """Message decoding policy for base 64-encoded messages into text.
-    
+
     Decodes base64-encoded messages to text (unicode). If the input content
     is not valid base 64, a DecodeError will be raised. Message data must
     support UTF-8.
     """
@@ -136,7 +104,7 @@ def decode(self, content, response):
 
 class BinaryBase64EncodePolicy(MessageEncodePolicy):
     """Base 64 message encoding policy for binary messages.
-    
+
     Encodes binary messages to base 64. If the input content is not bytes,
     a TypeError will be raised.
     """
@@ -149,7 +117,7 @@ def encode(self, content):
 
 class BinaryBase64DecodePolicy(MessageDecodePolicy):
     """Message decoding policy for base 64-encoded messages into bytes.
-    
+
     Decodes base64-encoded messages to bytes. If the input content is not
     valid base 64, a DecodeError will be raised.
     """
@@ -167,7 +135,7 @@ def decode(self, content, response):
 
 class TextXMLEncodePolicy(MessageEncodePolicy):
     """XML message encoding policy for text messages.
-    
+
     Encodes text (unicode) messages to XML. If the input content is not text,
     a TypeError will be raised.
     """
@@ -180,7 +148,7 @@ def encode(self, content):
 
 class TextXMLDecodePolicy(MessageDecodePolicy):
     """Message decoding policy for XML-encoded messages into text.
-    
+
     Decodes XML-encoded messages to text (unicode).
     """
 
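Per their docstrings, the renamed Base64 policies encode text over its UTF-8 bytes and reverse that on the way back. A standalone sketch of the round trip those policies implement (not the policy classes themselves, just the underlying transformation):

    from base64 import b64encode, b64decode

    message = u"hello queue"
    encoded = b64encode(message.encode("utf-8")).decode("utf-8")  # what the encode policy produces
    decoded = b64decode(encoded.encode("utf-8")).decode("utf-8")  # what the decode policy recovers
    assert decoded == message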
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py
index e9de0de09a94..43b1529ce988 100644
--- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py
@@ -100,6 +100,9 @@ def _add_authorization_header(self, request, string_to_sign):
             raise _wrap_exception(ex, AzureSigningError)
 
     def on_request(self, request, **kwargs):
+        if 'content-type' not in request.http_request.headers:
+            request.http_request.headers['content-type'] = 'application/xml; charset=utf-8'
+
         string_to_sign = \
             self._get_verb(request) + \
             self._get_headers(
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py
index 2f6148afeb11..5d338689908d 100644
--- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py
@@ -10,9 +10,10 @@
 )
 import logging
 try:
-    from urllib.parse import parse_qs
+    from urllib.parse import parse_qs, quote
 except ImportError:
     from urlparse import parse_qs  # type: ignore
+    from urllib2 import quote  # type: ignore
 
 import six
 
@@ -87,9 +88,7 @@ def __init__(
         self.require_encryption = kwargs.get('require_encryption', False)
         self.key_encryption_key = kwargs.get('key_encryption_key')
         self.key_resolver_function = kwargs.get('key_resolver_function')
-
-        self._config, self._pipeline = create_pipeline(
-            self.credential, storage_sdk=service, hosts=self._hosts, **kwargs)
+        self._config, self._pipeline = self._create_pipeline(self.credential, storage_sdk=service, **kwargs)
 
     def __enter__(self):
         self._client.__enter__()
@@ -145,6 +144,38 @@ def _format_query_string(self, sas_token, credential, snapshot=None, share_snaps
             credential = None
         return query_str.rstrip('?&'), credential
 
+    def _create_pipeline(self, credential, **kwargs):
+        # type: (Any, **Any) -> Tuple[Configuration, Pipeline]
+        credential_policy = None
+        if hasattr(credential, 'get_token'):
+            credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE)
+        elif isinstance(credential, SharedKeyCredentialPolicy):
+            credential_policy = credential
+        elif credential is not None:
+            raise TypeError("Unsupported credential: {}".format(credential))
+
+        config = kwargs.get('_configuration') or create_configuration(**kwargs)
+        if kwargs.get('_pipeline'):
+            return config, kwargs['_pipeline']
+        config.transport = kwargs.get('transport')  # type: ignore
+        if not config.transport:
+            config.transport = RequestsTransport(config)
+        policies = [
+            QueueMessagePolicy(),
+            config.headers_policy,
+            config.user_agent_policy,
+            StorageContentValidation(),
+            StorageRequestHook(**kwargs),
+            credential_policy,
+            ContentDecodePolicy(),
+            RedirectPolicy(**kwargs),
+            StorageHosts(hosts=self._hosts, **kwargs),
+            config.retry_policy,
+            config.logging_policy,
+            StorageResponseHook(**kwargs),
+        ]
+        return config, Pipeline(config.transport, policies=policies)
+
 
 def format_shared_key_credential(account, credential):
     if isinstance(credential, six.string_types):
@@ -219,7 +250,6 @@ def create_configuration(**kwargs):
     config.headers_policy = StorageHeadersPolicy(**kwargs)
     config.user_agent_policy = StorageUserAgentPolicy(**kwargs)
     config.retry_policy = kwargs.get('retry_policy') or ExponentialRetry(**kwargs)
-    config.redirect_policy = RedirectPolicy(**kwargs)
     config.logging_policy = StorageLoggingPolicy(**kwargs)
     config.proxy_policy = ProxyPolicy(**kwargs)
@@ -244,43 +274,10 @@ def create_configuration(**kwargs):
     return config
 
 
-def create_pipeline(credential, **kwargs):
-    # type: (Any, **Any) -> Tuple[Configuration, Pipeline]
-    credential_policy = None
-    if hasattr(credential, 'get_token'):
-        credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE)
-    elif isinstance(credential, SharedKeyCredentialPolicy):
-        credential_policy = credential
-    elif credential is not None:
-        raise TypeError("Unsupported credential: {}".format(credential))
-
-    config = kwargs.get('_configuration') or create_configuration(**kwargs)
-    if kwargs.get('_pipeline'):
-        return config, kwargs['_pipeline']
-    transport = kwargs.get('transport')  # type: HttpTransport
-    if not transport:
-        transport = RequestsTransport(config)
-    policies = [
-        QueueMessagePolicy(),
-        config.headers_policy,
-        config.user_agent_policy,
-        StorageContentValidation(),
-        StorageRequestHook(**kwargs),
-        credential_policy,
-        ContentDecodePolicy(),
-        config.redirect_policy,
-        StorageHosts(**kwargs),
-        config.retry_policy,
-        config.logging_policy,
-        StorageResponseHook(**kwargs),
-    ]
-    return config, Pipeline(transport, policies=policies)
-
-
 def parse_query(query_str):
     sas_values = QueryStringConstants.to_list()
     parsed_query = {k: v[0] for k, v in parse_qs(query_str).items()}
-    sas_params = ["{}={}".format(k, v) for k, v in parsed_query.items() if k in sas_values]
+    sas_params = ["{}={}".format(k, quote(v)) for k, v in parsed_query.items() if k in sas_values]
     sas_token = None
    if sas_params:
         sas_token = '&'.join(sas_params)
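The module-level create_pipeline removed further down this file is folded into the client as the _create_pipeline method added above. Its credential dispatch reduces to the sketch below; the factory and shared-key type are passed in only to keep the snippet self-contained — they stand in for BearerTokenCredentialPolicy and SharedKeyCredentialPolicy:

    def pick_credential_policy(credential, bearer_factory, shared_key_type):
        if credential is None:
            return None                        # anonymous or SAS-based access
        if hasattr(credential, 'get_token'):   # token credentials are duck-typed
            return bearer_factory(credential)
        if isinstance(credential, shared_key_type):
            return credential                  # already a pipeline policy
        raise TypeError("Unsupported credential: {}".format(credential))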
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py
new file mode 100644
index 000000000000..19d314f87892
--- /dev/null
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py
@@ -0,0 +1,86 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+from typing import (  # pylint: disable=unused-import
+    Union, Optional, Any, Iterable, Dict, List, Type, Tuple,
+    TYPE_CHECKING
+)
+import logging
+
+from azure.core.pipeline import AsyncPipeline
+try:
+    from azure.core.pipeline.transport import AioHttpTransport as AsyncTransport
+except ImportError:
+    from azure.core.pipeline.transport import AsyncioRequestsTransport as AsyncTransport
+from azure.core.pipeline.policies import (
+    ContentDecodePolicy,
+    BearerTokenCredentialPolicy,
+    AsyncRedirectPolicy)
+
+from .constants import STORAGE_OAUTH_SCOPE, DEFAULT_SOCKET_TIMEOUT
+from .authentication import SharedKeyCredentialPolicy
+from .base_client import create_configuration
+from .policies import (
+    StorageContentValidation,
+    StorageRequestHook,
+    StorageHosts,
+    QueueMessagePolicy)
+from .policies_async import AsyncStorageResponseHook
+
+if TYPE_CHECKING:
+    from azure.core.pipeline import Pipeline
+    from azure.core import Configuration
+_LOGGER = logging.getLogger(__name__)
+
+
+class AsyncStorageAccountHostsMixin(object):
+
+    def __enter__(self):
+        raise TypeError("Async client only supports 'async with'.")
+
+    def __exit__(self, *args):
+        pass
+
+    async def __aenter__(self):
+        await self._client.__aenter__()
+        return self
+
+    async def __aexit__(self, *args):
+        await self._client.__aexit__(*args)
+
+    def _create_pipeline(self, credential, **kwargs):
+        # type: (Any, **Any) -> Tuple[Configuration, Pipeline]
+        credential_policy = None
+        if hasattr(credential, 'get_token'):
+            credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE)
+        elif isinstance(credential, SharedKeyCredentialPolicy):
+            credential_policy = credential
+        elif credential is not None:
+            raise TypeError("Unsupported credential: {}".format(credential))
+
+        if 'connection_timeout' not in kwargs:
+            kwargs['connection_timeout'] = DEFAULT_SOCKET_TIMEOUT[0]  # type: ignore
+        config = kwargs.get('_configuration') or create_configuration(**kwargs)
+        if kwargs.get('_pipeline'):
+            return config, kwargs['_pipeline']
+        config.transport = kwargs.get('transport')  # type: ignore
+        if not config.transport:
+            config.transport = AsyncTransport(config)
+        policies = [
+            QueueMessagePolicy(),
+            config.headers_policy,
+            config.user_agent_policy,
+            StorageContentValidation(),
+            StorageRequestHook(**kwargs),
+            credential_policy,
+            ContentDecodePolicy(),
+            AsyncRedirectPolicy(**kwargs),
+            StorageHosts(hosts=self._hosts, **kwargs),  # type: ignore
+            config.retry_policy,
+            config.logging_policy,
+            AsyncStorageResponseHook(**kwargs),
+        ]
+        return config, AsyncPipeline(config.transport, policies=policies)
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/download_chunking.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/downloads.py
similarity index 97%
rename from sdk/storage/azure-storage-queue/azure/storage/queue/_shared/download_chunking.py
rename to sdk/storage/azure-storage-queue/azure/storage/queue/_shared/downloads.py
index e923a992e314..1d46ffc95293 100644
--- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/download_chunking.py
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/downloads.py
@@ -10,7 +10,6 @@
 
 from azure.core.exceptions import HttpResponseError
 
-from .models import ModifiedAccessConditions
 from .request_handlers import validate_and_format_range_headers
 from .response_handlers import process_storage_error, parse_length_from_content_range
 from .encryption import decrypt_blob
@@ -40,22 +39,25 @@ def process_range_and_offset(start_range, end_range, length, encryption):
 
 
 def process_content(data, start_offset, end_offset, encryption):
-    if encryption.get('key') is not None or encryption.get('resolver') is not None:
+    if data is None:
+        raise ValueError("Response cannot be None.")
+    content = b"".join(list(data))
+    if content and (encryption.get('key') is not None or encryption.get('resolver') is not None):
         try:
             return decrypt_blob(
                 encryption.get('required'),
                 encryption.get('key'),
                 encryption.get('resolver'),
-                data,
+                content,
                 start_offset,
-                end_offset)
+                end_offset,
+                data.response.headers)
         except Exception as error:
             raise HttpResponseError(
                 message="Decryption failed.",
                 response=data.response,
                 error=error)
-    else:
-        return b"".join(list(data))
+    return content
 
 
 class _ChunkDownloader(object):
@@ -209,7 +211,7 @@ def _write_to_stream(self, chunk_data, chunk_start):
             self.stream.write(chunk_data)
 
 
-class StorageStreamDownloader(object):
+class StorageStreamDownloader(object):  # pylint: disable=too-many-instance-attributes
     """A streaming object to download from Azure Storage.
     The stream downloader can be iterated, or download to an open file or stream
@@ -291,14 +293,14 @@ def __iter__(self):
                 # Use the length unless it is over the end of the file
                 data_end = min(self.file_size, self.length + 1)
 
-            downloader = SequentialBlobChunkDownloader(
+            downloader = SequentialChunkDownloader(
                 service=self.service,
                 total_size=self.download_size,
                 chunk_size=self.config.max_chunk_get_size,
                 current_progress=self.first_get_size,
                 start_range=self.initial_range[1] + 1,  # start where the first download ended
                 end_range=data_end,
-                stream=stream,
+                stream=None,
                 validate_content=self.validate_content,
                 encryption_options=self.encryption_options,
                 use_location=self.location_mode,
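The reworked process_content drains the response body into memory first and hands decrypt_blob raw bytes plus the response headers instead of the response object. Condensed to a standalone sketch (decrypt stands in for the real decrypt_blob call, and data mimics the pipeline's iterable response):

    def process_content_sketch(data, decrypt, encryption):
        if data is None:
            raise ValueError("Response cannot be None.")
        content = b"".join(list(data))   # drain the iterable response body
        needs_decryption = (encryption.get('key') is not None
                            or encryption.get('resolver') is not None)
        if content and needs_decryption:
            return decrypt(content, data.response.headers)
        return content                   # empty or unencrypted bodies pass through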
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/downloads_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/downloads_async.py
new file mode 100644
index 000000000000..37adcd93960a
--- /dev/null
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/downloads_async.py
@@ -0,0 +1,425 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import sys
+import asyncio
+from io import BytesIO
+from itertools import islice
+
+from azure.core.exceptions import HttpResponseError
+
+from .request_handlers import validate_and_format_range_headers
+from .response_handlers import process_storage_error, parse_length_from_content_range
+from .encryption import decrypt_blob
+from .downloads import process_range_and_offset
+
+
+async def process_content(data, start_offset, end_offset, encryption):
+    if data is None:
+        raise ValueError("Response cannot be None.")
+    content = data.response.body
+    if encryption.get('key') is not None or encryption.get('resolver') is not None:
+        try:
+            return decrypt_blob(
+                encryption.get('required'),
+                encryption.get('key'),
+                encryption.get('resolver'),
+                content,
+                start_offset,
+                end_offset,
+                data.response.headers)
+        except Exception as error:
+            raise HttpResponseError(
+                message="Decryption failed.",
+                response=data.response,
+                error=error)
+    return content
+
+
+class _AsyncChunkDownloader(object):  # pylint: disable=too-many-instance-attributes
+
+    def __init__(
+            self, service=None,
+            total_size=None,
+            chunk_size=None,
+            current_progress=None,
+            start_range=None,
+            end_range=None,
+            stream=None,
+            parallel=None,
+            validate_content=None,
+            encryption_options=None,
+            **kwargs):
+        self.service = service
+
+        # information on the download range/chunk size
+        self.chunk_size = chunk_size
+        self.total_size = total_size
+        self.start_index = start_range
+        self.end_index = end_range
+
+        # the destination that we will write to
+        self.stream = stream
+        self.stream_lock = asyncio.Lock() if parallel else None
+        self.progress_lock = asyncio.Lock() if parallel else None
+
+        # for a parallel download, the stream is always seekable, so we note down the current position
+        # in order to seek to the right place when out-of-order chunks come in
+        self.stream_start = stream.tell() if parallel else None
+
+        # download progress so far
+        self.progress_total = current_progress
+
+        # encryption
+        self.encryption_options = encryption_options
+
+        # parameters for each get operation
+        self.validate_content = validate_content
+        self.request_options = kwargs
+
+    def _calculate_range(self, chunk_start):
+        if chunk_start + self.chunk_size > self.end_index:
+            chunk_end = self.end_index
+        else:
+            chunk_end = chunk_start + self.chunk_size
+        return chunk_start, chunk_end
+
+    def get_chunk_offsets(self):
+        index = self.start_index
+        while index < self.end_index:
+            yield index
+            index += self.chunk_size
+
+    async def process_chunk(self, chunk_start):
+        chunk_start, chunk_end = self._calculate_range(chunk_start)
+        chunk_data = await self._download_chunk(chunk_start, chunk_end)
+        length = chunk_end - chunk_start
+        if length > 0:
+            await self._write_to_stream(chunk_data, chunk_start)
+            await self._update_progress(length)
+
+    async def yield_chunk(self, chunk_start):
+        chunk_start, chunk_end = self._calculate_range(chunk_start)
+        return await self._download_chunk(chunk_start, chunk_end)
+
+    async def _update_progress(self, length):
+        if self.progress_lock:
+            async with self.progress_lock:
+                self.progress_total += length
+        else:
+            self.progress_total += length
+
+    async def _write_to_stream(self, chunk_data, chunk_start):
+        if self.stream_lock:
+            async with self.stream_lock:
+                self.stream.seek(self.stream_start + (chunk_start - self.start_index))
+                self.stream.write(chunk_data)
+        else:
+            self.stream.write(chunk_data)
+
+    async def _download_chunk(self, chunk_start, chunk_end):
+        download_range, offset = process_range_and_offset(
+            chunk_start, chunk_end, chunk_end, self.encryption_options)
+        range_header, range_validation = validate_and_format_range_headers(
+            download_range[0],
+            download_range[1] - 1,
+            check_content_md5=self.validate_content)
+
+        try:
+            _, response = await self.service.download(
+                range=range_header,
+                range_get_content_md5=range_validation,
+                validate_content=self.validate_content,
+                data_stream_total=self.total_size,
+                download_stream_current=self.progress_total,
+                **self.request_options)
+        except HttpResponseError as error:
+            process_storage_error(error)
+
+        chunk_data = await process_content(response, offset[0], offset[1], self.encryption_options)
+
+        # This makes sure that if_match is set so that we can validate
+        # that subsequent downloads are to an unmodified blob
+        if self.request_options.get('modified_access_conditions'):
+            self.request_options['modified_access_conditions'].if_match = response.properties.etag
+
+        return chunk_data
+
+class StorageStreamDownloader(object):  # pylint: disable=too-many-instance-attributes
+    """A streaming object to download from Azure Storage.
+    The stream downloader can be iterated, or download to an open file or stream
+    over multiple threads.
+    """
+
+    def __init__(
+            self, service=None,
+            config=None,
+            offset=None,
+            length=None,
+            validate_content=None,
+            encryption_options=None,
+            **kwargs):
+        self.service = service
+        self.config = config
+        self.offset = offset
+        self.length = length
+        self.validate_content = validate_content
+        self.encryption_options = encryption_options or {}
+        self.request_options = kwargs
+        self.location_mode = None
+        self._download_complete = False
+        self._current_content = None
+        self._iter_downloader = None
+        self._iter_chunks = None
+
+        # The service only provides transactional MD5s for chunks under 4MB.
+        # If validate_content is on, get only self.MAX_CHUNK_GET_SIZE for the first
+        # chunk so a transactional MD5 can be retrieved.
+        self.first_get_size = self.config.max_single_get_size if not self.validate_content \
+            else self.config.max_chunk_get_size
+        initial_request_start = self.offset if self.offset is not None else 0
+        if self.length is not None and self.length - self.offset < self.first_get_size:
+            initial_request_end = self.length
+        else:
+            initial_request_end = initial_request_start + self.first_get_size - 1
+
+        self.initial_range, self.initial_offset = process_range_and_offset(
+            initial_request_start, initial_request_end, self.length, self.encryption_options)
+        self.download_size = None
+        self.file_size = None
+        self.response = None
+        self.properties = None
+
+    def __len__(self):
+        return self.download_size
+
+    def __iter__(self):
+        raise TypeError("Async stream must be iterated asynchronously.")
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        """Iterate through responses."""
+        if self._current_content is None:
+            if self.download_size == 0:
+                self._current_content = b""
+            else:
+                self._current_content = await process_content(
+                    self.response, self.initial_offset[0], self.initial_offset[1], self.encryption_options)
+            if not self._download_complete:
+                data_end = self.file_size
+                if self.length is not None:
+                    # Use the length unless it is over the end of the file
+                    data_end = min(self.file_size, self.length + 1)
+                self._iter_downloader = _AsyncChunkDownloader(
+                    service=self.service,
+                    total_size=self.download_size,
+                    chunk_size=self.config.max_chunk_get_size,
+                    current_progress=self.first_get_size,
+                    start_range=self.initial_range[1] + 1,  # start where the first download ended
+                    end_range=data_end,
+                    stream=None,
+                    parallel=False,
+                    validate_content=self.validate_content,
+                    encryption_options=self.encryption_options,
+                    use_location=self.location_mode,
+                    **self.request_options)
+                self._iter_chunks = self._iter_downloader.get_chunk_offsets()
+        elif self._download_complete:
+            raise StopAsyncIteration("Download complete")
+        else:
+            try:
+                chunk = next(self._iter_chunks)
+            except StopIteration:
+                raise StopAsyncIteration("DownloadComplete")
+            self._current_content = await self._iter_downloader.yield_chunk(chunk)
+
+        return self._current_content
+
+    async def setup(self, extra_properties=None):
+        if self.response:
+            raise ValueError("Download stream already initialized.")
+        self.response = await self._initial_request()
+        self.properties = self.response.properties
+
+        # Set the content length to the download size instead of the size of
+        # the last range
+        self.properties.size = self.download_size
+
+        # Overwrite the content range to the user requested range
+        self.properties.content_range = 'bytes {0}-{1}/{2}'.format(self.offset, self.length, self.file_size)
+
+        # Set additional properties according to download type
+        if extra_properties:
+            for prop, value in extra_properties.items():
+                setattr(self.properties, prop, value)
+
+        # Overwrite the content MD5 as it is the MD5 for the last range instead
+        # of the stored MD5
+        # TODO: Set to the stored MD5 when the service returns this
+        self.properties.content_md5 = None
+
+    async def _initial_request(self):
+        range_header, range_validation = validate_and_format_range_headers(
+            self.initial_range[0],
+            self.initial_range[1],
+            start_range_required=False,
+            end_range_required=False,
+            check_content_md5=self.validate_content)
+
+        try:
+            location_mode, response = await self.service.download(
+                range=range_header,
+                range_get_content_md5=range_validation,
+                validate_content=self.validate_content,
+                data_stream_total=None,
+                download_stream_current=0,
+                **self.request_options)
+
+            # Check the location we read from to ensure we use the same one
+            # for subsequent requests.
+            self.location_mode = location_mode
+
+            # Parse the total file size and adjust the download size if ranges
+            # were specified
+            self.file_size = parse_length_from_content_range(response.properties.content_range)
+            if self.length is not None:
+                # Use the length unless it is over the end of the file
+                self.download_size = min(self.file_size, self.length - self.offset + 1)
+            elif self.offset is not None:
+                self.download_size = self.file_size - self.offset
+            else:
+                self.download_size = self.file_size
+
+        except HttpResponseError as error:
+            if self.offset is None and error.response.status_code == 416:
+                # Get range will fail on an empty file. If the user did not
+                # request a range, do a regular get request in order to get
+                # any properties.
+                try:
+                    _, response = await self.service.download(
+                        validate_content=self.validate_content,
+                        data_stream_total=0,
+                        download_stream_current=0,
+                        **self.request_options)
+                except HttpResponseError as error:
+                    process_storage_error(error)
+
+                # Set the download size to empty
+                self.download_size = 0
+                self.file_size = 0
+            else:
+                process_storage_error(error)
+
+        # If the file is small, the download is complete at this point.
+        # If file size is large, download the rest of the file in chunks.
+        if response.properties.size != self.download_size:
+            # Lock on the etag. This can be overridden by the user by specifying '*'
+            if self.request_options.get('modified_access_conditions'):
+                if not self.request_options['modified_access_conditions'].if_match:
+                    self.request_options['modified_access_conditions'].if_match = response.properties.etag
+        else:
+            self._download_complete = True
+        return response
+
+    async def content_as_bytes(self, max_connections=1):
+        """Download the contents of this file.
+        This operation is blocking until all data is downloaded.
+        :param int max_connections:
+            The number of parallel connections with which to download.
+        :rtype: bytes
+        """
+        stream = BytesIO()
+        await self.download_to_stream(stream, max_connections=max_connections)
+        return stream.getvalue()
+
+    async def content_as_text(self, max_connections=1, encoding='UTF-8'):
+        """Download the contents of this file, and decode as text.
+        This operation is blocking until all data is downloaded.
+        :param int max_connections:
+            The number of parallel connections with which to download.
+        :rtype: str
+        """
+        content = await self.content_as_bytes(max_connections=max_connections)
+        return content.decode(encoding)
+
+    async def download_to_stream(self, stream, max_connections=1):
+        """Download the contents of this file to a stream.
+        :param stream:
+            The stream to download to. This can be an open file-handle,
+            or any writable stream. The stream must be seekable if the download
+            uses more than one parallel connection.
+        :returns: The properties of the downloaded file.
+        :rtype: Any
+        """
+        if self._iter_downloader:
+            raise ValueError("Stream is currently being iterated.")
+
+        # the stream must be seekable if parallel download is required
+        parallel = max_connections > 1
+        if parallel:
+            error_message = "Target stream handle must be seekable."
+            if sys.version_info >= (3,) and not stream.seekable():
+                raise ValueError(error_message)
+
+            try:
+                stream.seek(stream.tell())
+            except (NotImplementedError, AttributeError):
+                raise ValueError(error_message)
+
+        if self.download_size == 0:
+            content = b""
+        else:
+            content = await process_content(
+                self.response, self.initial_offset[0], self.initial_offset[1], self.encryption_options)
+
+        # Write the content to the user stream
+        if content is not None:
+            stream.write(content)
+        if self._download_complete:
+            return self.properties
+
+        data_end = self.file_size
+        if self.length is not None:
+            # Use the length unless it is over the end of the file
+            data_end = min(self.file_size, self.length + 1)
+
+        downloader = _AsyncChunkDownloader(
+            service=self.service,
+            total_size=self.download_size,
+            chunk_size=self.config.max_chunk_get_size,
+            current_progress=self.first_get_size,
+            start_range=self.initial_range[1] + 1,  # start where the first download ended
+            end_range=data_end,
+            stream=stream,
+            parallel=parallel,
+            validate_content=self.validate_content,
+            encryption_options=self.encryption_options,
+            use_location=self.location_mode,
+            **self.request_options)
+
+        dl_tasks = downloader.get_chunk_offsets()
+        running_futures = [
+            asyncio.ensure_future(downloader.process_chunk(d))
+            for d in islice(dl_tasks, 0, max_connections)
+        ]
+        while running_futures:
+            # Wait for some download to finish before adding a new one
+            _done, running_futures = await asyncio.wait(
+                running_futures, return_when=asyncio.FIRST_COMPLETED)
+            try:
+                next_chunk = next(dl_tasks)
+            except StopIteration:
+                break
+            else:
+                running_futures.add(asyncio.ensure_future(downloader.process_chunk(next_chunk)))
+
+        if running_futures:
+            # Wait for the remaining downloads to finish
+            await asyncio.wait(running_futures)
+        return self.properties
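download_to_stream above caps the number of in-flight chunk downloads at max_connections, topping the window back up each time asyncio.wait reports a completion. The same scheduling, reduced to a self-contained sketch (coro_factory is any coroutine function taking a chunk offset):

    import asyncio
    from itertools import islice

    async def run_windowed(coro_factory, chunks, max_connections):
        chunks = iter(chunks)
        running = {asyncio.ensure_future(coro_factory(c))
                   for c in islice(chunks, max_connections)}
        while running:
            _done, running = await asyncio.wait(
                running, return_when=asyncio.FIRST_COMPLETED)
            try:
                running.add(asyncio.ensure_future(coro_factory(next(chunks))))
            except StopIteration:
                break
        if running:
            await asyncio.wait(running)  # drain whatever is still pending

asyncio.wait returns (done, pending) sets, which is why the in-flight collection can be rebound on each iteration and then grown with add().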
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/encryption.py
index b1178eaa9262..10077fedcfb7 100644
--- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/encryption.py
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/encryption.py
@@ -128,7 +128,6 @@ def __init__(self, content_encryption_IV, encryption_agent, wrapped_content_key,
 def _generate_encryption_data_dict(kek, cek, iv):
     '''
     Generates and returns the encryption metadata as a dict.
-
     :param object kek: The key encryption key. See calling functions for more information.
     :param bytes cek: The content encryption key.
     :param bytes iv: The initialization vector.
@@ -162,7 +161,6 @@ def _dict_to_encryption_data(encryption_data_dict):
     '''
     Converts the specified dictionary to an EncryptionData object for
     eventual use in decryption.
-
     :param dict encryption_data_dict:
         The dictionary containing the encryption data.
     :return: an _EncryptionData object built from the dictionary.
@@ -198,7 +196,6 @@ def _dict_to_encryption_data(encryption_data_dict):
 def _generate_AES_CBC_cipher(cek, iv):
     '''
     Generates and returns an encryption cipher for AES CBC using the given cek and iv.
-
     :param bytes[] cek: The content encryption key for the cipher.
     :param bytes[] iv: The initialization vector for the cipher.
     :return: A cipher for encrypting in AES256 CBC.
@@ -259,7 +256,6 @@ def _decrypt_message(message, encryption_data, key_encryption_key=None, resolver
     Decrypts the given ciphertext using AES256 in CBC mode with 128 bit padding.
     Unwraps the content-encryption-key using the user-provided or resolved key-encryption-key (kek).
     Returns the original plaintext.
-
     :param str message:
         The ciphertext to be decrypted.
     :param _EncryptionData encryption_data:
@@ -303,7 +299,6 @@ def encrypt_blob(blob, key_encryption_key):
     Returns a json-formatted string containing the encryption metadata. This method should
     only be used when a blob is small enough for single shot upload. Encrypting larger blobs
     is done as a part of the upload_blob_chunks method.
-
     :param bytes blob:
         The blob to be encrypted.
     :param object key_encryption_key:
@@ -342,7 +337,6 @@ def encrypt_blob(blob, key_encryption_key):
 def generate_blob_encryption_data(key_encryption_key):
     '''
     Generates the encryption_metadata for the blob.
-
     :param bytes key_encryption_key:
         The key-encryption-key used to wrap the cek associated with this blob.
     :return: A tuple containing the cek and iv for this blob as well as the
@@ -366,10 +360,9 @@ def generate_blob_encryption_data(key_encryption_key):
 
 
 def decrypt_blob(require_encryption, key_encryption_key, key_resolver,
-                 response, start_offset, end_offset):
+                 content, start_offset, end_offset, response_headers):
     '''
     Decrypts the given blob contents and returns only the requested range.
-
     :param bool require_encryption:
         Whether or not the calling blob service requires objects to be decrypted.
     :param object key_encryption_key:
@@ -383,14 +376,8 @@ def decrypt_blob(require_encryption, key_encryption_key, key_resolver,
     :return: The decrypted blob content.
     :rtype: bytes
     '''
-    if response is None:
-        raise ValueError("Response cannot be None.")
-    content = b"".join(list(response))
-    if not content:
-        return content
-
     try:
-        encryption_data = _dict_to_encryption_data(loads(response.response.headers['x-ms-meta-encryptiondata']))
+        encryption_data = _dict_to_encryption_data(loads(response_headers['x-ms-meta-encryptiondata']))
     except:  # pylint: disable=bare-except
         if require_encryption:
             raise ValueError(
@@ -402,12 +389,12 @@ def decrypt_blob(require_encryption, key_encryption_key, key_resolver,
     if encryption_data.encryption_agent.encryption_algorithm != _EncryptionAlgorithm.AES_CBC_256:
         raise ValueError('Specified encryption algorithm is not supported.')
 
-    blob_type = response.response.headers['x-ms-blob-type']
+    blob_type = response_headers['x-ms-blob-type']
 
     iv = None
     unpad = False
-    if 'content-range' in response.response.headers:
-        content_range = response.response.headers['content-range']
+    if 'content-range' in response_headers:
+        content_range = response_headers['content-range']
 
         # Format: 'bytes x-y/size'
         # Ignore the word 'bytes'
@@ -463,7 +450,6 @@ def encrypt_queue_message(message, key_encryption_key):
     Encrypts the given plain text message using AES256 in CBC mode with 128 bit padding.
     Wraps the generated content-encryption-key using the user-provided key-encryption-key (kek).
     Returns a json-formatted string containing the encrypted message and the encryption metadata.
-
     :param object message:
         The plain text message to be encrypted.
     :param object key_encryption_key:
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py
index 30e4506254d6..7185141649f9 100644
--- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py
@@ -210,30 +210,6 @@ def get(self, key, default=None):
         return default
 
 
-class ModifiedAccessConditions(object):
-    """Additional parameters for a set of operations.
-
-    :param if_modified_since: Specify this header value to operate only on a
-        blob if it has been modified since the specified date/time.
-    :type if_modified_since: datetime
-    :param if_unmodified_since: Specify this header value to operate only on a
-        blob if it has not been modified since the specified date/time.
-    :type if_unmodified_since: datetime
-    :param if_match: Specify an ETag value to operate only on blobs with a
-        matching value.
-    :type if_match: str
-    :param if_none_match: Specify an ETag value to operate only on blobs
-        without a matching value.
-    :type if_none_match: str
-    """
-
-    def __init__(self, **kwargs):
-        self.if_modified_since = kwargs.get('if_modified_since', None)
-        self.if_unmodified_since = kwargs.get('if_unmodified_since', None)
-        self.if_match = kwargs.get('if_match', None)
-        self.if_none_match = kwargs.get('if_none_match', None)
-
-
 class LocationMode(object):
     """
     Specifies the location the request should be sent to. This mode only applies
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py
index 5b0212fd9090..f3e7c182d37a 100644
--- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py
@@ -61,6 +61,20 @@ def encode_base64(data):
     return encoded.decode('utf-8')
 
 
+def is_exhausted(settings):
+    """Are we out of retries?"""
+    retry_counts = (settings['total'], settings['connect'], settings['read'], settings['status'])
+    retry_counts = list(filter(None, retry_counts))
+    if not retry_counts:
+        return False
+    return min(retry_counts) < 0
+
+
+def retry_hook(settings, **kwargs):
+    if settings['hook']:
+        settings['hook'](retry_count=settings['count'] - 1, location_mode=settings['mode'], **kwargs)
+
+
 def is_retry(response, mode):
     """Is this method/status code retryable? (Based on whitelists and control
     variables such as the number of total retries to allow, whether to
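is_exhausted, now module-level so the async retry policies below can share it, reports exhaustion once any configured counter has been decremented below zero. filter(None, ...) drops unset (None) counters, and incidentally also drops a counter sitting at exactly 0, so exhaustion registers on the decrement that takes it negative. Worked through with hypothetical settings values:

    counts = (3, None, None, None)       # total=3, others unconfigured
    assert list(filter(None, counts)) == [3]       # None counters are ignored
    exhausted = (-1, 2, 2, 2)            # total has been decremented past zero
    assert min(filter(None, exhausted)) < 0        # any negative counter wins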
@@ -428,15 +442,6 @@ def sleep(self, settings, transport):
         return transport.sleep(backoff)
 
-    def is_exhausted(self, settings):  # pylint: disable=no-self-use
-        """Are we out of retries?"""
-        retry_counts = (settings['total'], settings['connect'], settings['read'], settings['status'])
-        retry_counts = list(filter(None, retry_counts))
-        if not retry_counts:
-            return False
-
-        return min(retry_counts) < 0
-
     def increment(self, settings, request, response=None, error=None):
         """Increment the retry counters.
@@ -467,7 +472,7 @@ def increment(self, settings, request, response=None, error=None):
             settings['status'] -= 1
             settings['history'].append(RequestHistory(request, http_response=response))
 
-        if not self.is_exhausted(settings):
+        if not is_exhausted(settings):
             if request.method not in ['PUT'] and settings['retry_secondary']:
                 self._set_next_host_location(settings, request)
 
@@ -482,13 +487,6 @@ def increment(self, settings, request, response=None, error=None):
                 except UnsupportedOperation:
                     # if body is not seekable, then retry would not work
                     return False
-            if settings['hook']:
-                settings['hook'](
-                    request=request,
-                    response=response,
-                    error=error,
-                    retry_count=settings['count'],
-                    location_mode=settings['mode'])
             settings['count'] += 1
             return True
         return False
@@ -506,14 +504,23 @@ def send(self, request):
                         request=request.http_request,
                         response=response.http_response)
                     if retries_remaining:
+                        retry_hook(
+                            retry_settings,
+                            request=request.http_request,
+                            response=response.http_response,
+                            error=None)
                         self.sleep(retry_settings, request.context.transport)
-                    continue
                 break
             except AzureError as err:
                 retries_remaining = self.increment(
                     retry_settings, request=request.http_request, error=err)
                 if retries_remaining:
+                    retry_hook(
+                        retry_settings,
+                        request=request.http_request,
+                        response=None,
+                        error=err)
                     self.sleep(retry_settings, request.context.transport)
                     continue
                 raise err
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py
new file mode 100644
index 000000000000..b84ba562b948
--- /dev/null
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py
@@ -0,0 +1,229 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import asyncio
+import random
+import logging
+from typing import Any, TYPE_CHECKING
+
+from azure.core.pipeline.policies import AsyncHTTPPolicy
+from azure.core.exceptions import AzureError
+
+from .policies import is_retry, StorageRetryPolicy
+
+if TYPE_CHECKING:
+    from azure.core.pipeline import PipelineRequest, PipelineResponse
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+async def retry_hook(settings, **kwargs):
+    if settings['hook']:
+        if asyncio.iscoroutine(settings['hook']):
+            await settings['hook'](
+                retry_count=settings['count'] - 1,
+                location_mode=settings['mode'],
+                **kwargs)
+        else:
+            settings['hook'](
+                retry_count=settings['count'] - 1,
+                location_mode=settings['mode'],
+                **kwargs)
+
+
+class AsyncStorageResponseHook(AsyncHTTPPolicy):
+
+    def __init__(self, **kwargs):  # pylint: disable=unused-argument
+        self._response_callback = kwargs.get('raw_response_hook')
+        super(AsyncStorageResponseHook, self).__init__()
+
+    async def send(self, request):
+        # type: (PipelineRequest) -> PipelineResponse
+        data_stream_total = request.context.get('data_stream_total') or \
+            request.context.options.pop('data_stream_total', None)
+        download_stream_current = request.context.get('download_stream_current') or \
+            request.context.options.pop('download_stream_current', None)
+        upload_stream_current = request.context.get('upload_stream_current') or \
+            request.context.options.pop('upload_stream_current', None)
+        response_callback = request.context.get('response_callback') or \
+            request.context.options.pop('raw_response_hook', self._response_callback)
+
+        response = await self.next.send(request)
+        await response.http_response.load_body()
+        response.http_response.internal_response.body = response.http_response.body()
+
+        will_retry = is_retry(response, request.context.options.get('mode'))
+        if not will_retry and download_stream_current is not None:
+            download_stream_current += int(response.http_response.headers.get('Content-Length', 0))
+            if data_stream_total is None:
+                content_range = response.http_response.headers.get('Content-Range')
+                if content_range:
+                    data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1])
+                else:
+                    data_stream_total = download_stream_current
+        elif not will_retry and upload_stream_current is not None:
+            upload_stream_current += int(response.http_request.headers.get('Content-Length', 0))
+        for pipeline_obj in [request, response]:
+            pipeline_obj.context['data_stream_total'] = data_stream_total
+            pipeline_obj.context['download_stream_current'] = download_stream_current
+            pipeline_obj.context['upload_stream_current'] = upload_stream_current
+        if response_callback:
+            if asyncio.iscoroutine(response_callback):
+                await response_callback(response)
+            else:
+                response_callback(response)
+            request.context['response_callback'] = response_callback
+        return response
+
+class AsyncStorageRetryPolicy(StorageRetryPolicy):
+    """
+    The base class for Exponential and Linear retries containing shared code.
+ """ + + async def sleep(self, settings, transport): + backoff = self.get_backoff_time(settings) + if not backoff or backoff < 0: + return + await transport.sleep(backoff) + + async def send(self, request): + retries_remaining = True + response = None + retry_settings = self.configure_retries(request) + while retries_remaining: + try: + response = await self.next.send(request) + if is_retry(response, retry_settings['mode']): + retries_remaining = self.increment( + retry_settings, + request=request.http_request, + response=response.http_response) + if retries_remaining: + await retry_hook( + retry_settings, + request=request.http_request, + response=response.http_response, + error=None) + await self.sleep(retry_settings, request.context.transport) + continue + break + except AzureError as err: + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err) + if retries_remaining: + await retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err) + await self.sleep(retry_settings, request.context.transport) + continue + raise err + if retry_settings['history']: + response.context['history'] = retry_settings['history'] + response.http_response.location_mode = retry_settings['mode'] + return response + + +class NoRetry(AsyncStorageRetryPolicy): + + def __init__(self): + super(NoRetry, self).__init__(retry_total=0) + + def increment(self, *args, **kwargs): # pylint: disable=unused-argument,arguments-differ + return False + + +class ExponentialRetry(AsyncStorageRetryPolicy): + """Exponential retry.""" + + def __init__(self, initial_backoff=15, increment_base=3, retry_total=3, + retry_to_secondary=False, random_jitter_range=3, **kwargs): + ''' + Constructs an Exponential retry object. The initial_backoff is used for + the first retry. Subsequent retries are retried after initial_backoff + + increment_power^retry_count seconds. For example, by default the first retry + occurs after 15 seconds, the second after (15+3^1) = 18 seconds, and the + third after (15+3^2) = 24 seconds. + + :param int initial_backoff: + The initial backoff interval, in seconds, for the first retry. + :param int increment_base: + The base, in seconds, to increment the initial_backoff by after the + first retry. + :param int max_attempts: + The maximum number of retry attempts. + :param bool retry_to_secondary: + Whether the request should be retried to secondary, if able. This should + only be enabled of RA-GRS accounts are used and potentially stale data + can be handled. + :param int random_jitter_range: + A number in seconds which indicates a range to jitter/randomize for the back-off interval. + For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3. + ''' + self.initial_backoff = initial_backoff + self.increment_base = increment_base + self.random_jitter_range = random_jitter_range + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + + def get_backoff_time(self, settings): + """ + Calculates how long to sleep before retrying. + + :return: + An integer indicating how long to wait before retrying the request, + or None to indicate no retry should be performed. 
+        :rtype: int or None
+        """
+        random_generator = random.Random()
+        backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count']))
+        random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0
+        random_range_end = backoff + self.random_jitter_range
+        return random_generator.uniform(random_range_start, random_range_end)
+
+
+class LinearRetry(AsyncStorageRetryPolicy):
+    """Linear retry."""
+
+    def __init__(self, backoff=15, retry_total=3, retry_to_secondary=False, random_jitter_range=3, **kwargs):
+        """
+        Constructs a Linear retry object.
+
+        :param int backoff:
+            The backoff interval, in seconds, between retries.
+        :param int max_attempts:
+            The maximum number of retry attempts.
+        :param bool retry_to_secondary:
+            Whether the request should be retried to secondary, if able. This should
+            only be enabled if RA-GRS accounts are used and potentially stale data
+            can be handled.
+        :param int random_jitter_range:
+            A number in seconds which indicates a range to jitter/randomize for the back-off interval.
+            For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3.
+        """
+        self.backoff = backoff
+        self.random_jitter_range = random_jitter_range
+        super(LinearRetry, self).__init__(
+            retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)
+
+    def get_backoff_time(self, settings):
+        """
+        Calculates how long to sleep before retrying.
+
+        :return:
+            An integer indicating how long to wait before retrying the request,
+            or None to indicate no retry should be performed.
+        :rtype: int or None
+        """
+        random_generator = random.Random()
+        # the backoff interval normally does not change, however there is the possibility
+        # that it was modified by accessing the property directly after initializing the object
+        random_range_start = self.backoff - self.random_jitter_range \
+            if self.backoff > self.random_jitter_range else 0
+        random_range_end = self.backoff + self.random_jitter_range
+        return random_generator.uniform(random_range_start, random_range_end)
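Checking the ExponentialRetry schedule from its docstring (pre-jitter values only; get_backoff_time then adds uniform jitter of +/- random_jitter_range seconds, flooring the range at 0):

    def backoff(count, initial_backoff=15, increment_base=3):
        return initial_backoff + (0 if count == 0 else increment_base ** count)

    assert [backoff(n) for n in range(3)] == [15, 18, 24]  # 15, 15+3^1, 15+3^2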
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/upload_chunking.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/uploads.py
similarity index 56%
rename from sdk/storage/azure-storage-queue/azure/storage/queue/_shared/upload_chunking.py
rename to sdk/storage/azure-storage-queue/azure/storage/queue/_shared/uploads.py
index 86a3f4224ffa..2b269fb1d0ba 100644
--- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/upload_chunking.py
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/uploads.py
@@ -5,14 +5,14 @@
 # --------------------------------------------------------------------------
 # pylint: disable=no-self-use
 
+from concurrent import futures
 from io import (BytesIO, IOBase, SEEK_CUR, SEEK_END, SEEK_SET, UnsupportedOperation)
 from threading import Lock
-
+from itertools import islice
 from math import ceil
 
 import six
 
-from .models import ModifiedAccessConditions
 from . import encode_base64, url_quote
 from .request_handlers import get_length
 from .response_handlers import return_response_headers
@@ -23,93 +23,66 @@
 _ERROR_VALUE_SHOULD_BE_SEEKABLE_STREAM = '{0} should be a seekable file-like/io.IOBase type stream object.'
 
-def upload_file_chunks(file_service, file_size, block_size, stream, max_connections,
-                       validate_content, timeout, **kwargs):
-    uploader = FileChunkUploader(
-        file_service,
-        file_size,
-        block_size,
-        stream,
-        max_connections > 1,
-        validate_content,
-        timeout,
-        **kwargs
-    )
-    if max_connections > 1:
-        import concurrent.futures
-        executor = concurrent.futures.ThreadPoolExecutor(max_connections)
-        range_ids = list(executor.map(uploader.process_chunk, uploader.get_chunk_offsets()))
-    else:
-        if file_size is not None:
-            range_ids = [uploader.process_chunk(start) for start in uploader.get_chunk_offsets()]
+def _parallel_uploads(executor, uploader, pending, running):
+    range_ids = []
+    while True:
+        # Wait for some download to finish before adding a new one
+        done, running = futures.wait(running, return_when=futures.FIRST_COMPLETED)
+        range_ids.extend([chunk.result() for chunk in done])
+        try:
+            next_chunk = next(pending)
+        except StopIteration:
+            break
         else:
-            range_ids = uploader.process_all_unknown_size()
-    return range_ids
+            running.add(executor.submit(uploader.process_chunk, next_chunk))
+    # Wait for the remaining uploads to finish
+    done, _running = futures.wait(running)
+    range_ids.extend([chunk.result() for chunk in done])
+    return range_ids
 
-def upload_blob_chunks(blob_service, blob_size, block_size, stream, max_connections, validate_content,  # pylint: disable=too-many-locals
-                       access_conditions, uploader_class, append_conditions=None, modified_access_conditions=None,
-                       timeout=None, content_encryption_key=None, initialization_vector=None, **kwargs):
-    encryptor, padder = get_blob_encryptor_and_padder(
-        content_encryption_key,
-        initialization_vector,
-        uploader_class is not PageBlobChunkUploader)
+def upload_data_chunks(
+        service=None,
+        uploader_class=None,
+        total_size=None,
+        chunk_size=None,
+        max_connections=None,
+        stream=None,
+        validate_content=None,
+        encryption_options=None,
+        **kwargs):
+
+    if encryption_options:
+        encryptor, padder = get_blob_encryptor_and_padder(
+            encryption_options.get('key'),
+            encryption_options.get('vector'),
+            uploader_class is not PageBlobChunkUploader)
+        kwargs['encryptor'] = encryptor
+        kwargs['padder'] = padder
+
+    parallel = max_connections > 1
+    if parallel and 'modified_access_conditions' in kwargs:
+        # Access conditions do not work with parallelism
+        kwargs['modified_access_conditions'] = None
 
     uploader = uploader_class(
-        blob_service,
-        blob_size,
-        block_size,
-        stream,
-        max_connections > 1,
-        validate_content,
-        access_conditions,
-        append_conditions,
-        timeout,
-        encryptor,
-        padder,
-        **kwargs
-    )
-
-    # Access conditions do not work with parallelism
-    if max_connections > 1:
-        uploader.modified_access_conditions = None
-    else:
-        uploader.modified_access_conditions = modified_access_conditions
-
-    if max_connections > 1:
-        import concurrent.futures
-        from threading import BoundedSemaphore
-
-        # Ensures we bound the chunking so we only buffer and submit 'max_connections'
-        # amount of work items to the executor. This is necessary as the executor queue will keep
-        # accepting submitted work items, which results in buffering all the blocks if
-        # the max_connections + 1 ensures the next chunk is already buffered and ready for when
-        # the worker thread is available.
-        chunk_throttler = BoundedSemaphore(max_connections + 1)
-
-        executor = concurrent.futures.ThreadPoolExecutor(max_connections)
-        futures = []
-        running_futures = []
-
-        # Check for exceptions and fail fast.
- for chunk in uploader.get_chunk_streams(): - for f in running_futures: - if f.done(): - if f.exception(): - raise f.exception() - running_futures.remove(f) - - chunk_throttler.acquire() - future = executor.submit(uploader.process_chunk, chunk) - - # Calls callback upon completion (even if the callback was added after the Future task is done). - future.add_done_callback(lambda x: chunk_throttler.release()) - futures.append(future) - running_futures.append(future) - - # result() will wait until completion and also raise any exceptions that may have been set. - range_ids = [f.result() for f in futures] + service=service, + total_size=total_size, + chunk_size=chunk_size, + stream=stream, + parallel=parallel, + validate_content=validate_content, + **kwargs) + + if parallel: + executor = futures.ThreadPoolExecutor(max_connections) + upload_tasks = uploader.get_chunk_streams() + running_futures = [ + executor.submit(uploader.process_chunk, u) + for u in islice(upload_tasks, 0, max_connections) + ] + range_ids = _parallel_uploads(executor, uploader, upload_tasks, running_futures) else: range_ids = [uploader.process_chunk(result) for result in uploader.get_chunk_streams()] @@ -118,59 +91,55 @@ def upload_blob_chunks(blob_service, blob_size, block_size, stream, max_connecti return uploader.response_headers -def upload_blob_substream_blocks(blob_service, blob_size, block_size, stream, max_connections, - validate_content, access_conditions, uploader_class, - append_conditions=None, modified_access_conditions=None, timeout=None, **kwargs): - +def upload_substream_blocks( + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_connections=None, + stream=None, + **kwargs): + parallel = max_connections > 1 + if parallel and 'modified_access_conditions' in kwargs: + # Access conditions do not work with parallelism + kwargs['modified_access_conditions'] = None uploader = uploader_class( - blob_service, - blob_size, - block_size, - stream, - max_connections > 1, - validate_content, - access_conditions, - append_conditions, - timeout, - None, - None, - **kwargs - ) - # ETag matching does not work with parallelism as a ranged upload may start - # before the previous finishes and provides an etag - if max_connections > 1: - uploader.modified_access_conditions = None - else: - uploader.modified_access_conditions = modified_access_conditions - - if max_connections > 1: - import concurrent.futures - executor = concurrent.futures.ThreadPoolExecutor(max_connections) - range_ids = list(executor.map(uploader.process_substream_block, uploader.get_substream_blocks())) - else: - range_ids = [uploader.process_substream_block(result) for result in uploader.get_substream_blocks()] - - return range_ids - - -class _BlobChunkUploader(object): # pylint: disable=too-many-instance-attributes - - def __init__(self, blob_service, blob_size, chunk_size, stream, parallel, validate_content, - access_conditions, append_conditions, timeout, encryptor, padder, **kwargs): - self.blob_service = blob_service - self.blob_size = blob_size + service=service, + total_size=total_size, + chunk_size=chunk_size, + stream=stream, + parallel=parallel, + **kwargs) + + if parallel: + executor = futures.ThreadPoolExecutor(max_connections) + upload_tasks = uploader.get_substream_blocks() + running_futures = [ + executor.submit(uploader.process_substream_block, u) + for u in islice(upload_tasks, 0, max_connections) + ] + return _parallel_uploads(executor, uploader, upload_tasks, running_futures) + return 
[uploader.process_substream_block(b) for b in uploader.get_substream_blocks()] + + +class _ChunkUploader(object): # pylint: disable=too-many-instance-attributes + + def __init__(self, service, total_size, chunk_size, stream, parallel, encryptor=None, padder=None, **kwargs): + self.service = service + self.total_size = total_size self.chunk_size = chunk_size self.stream = stream self.parallel = parallel + + # Stream management self.stream_start = stream.tell() if parallel else None self.stream_lock = Lock() if parallel else None + + # Progress feedback self.progress_total = 0 self.progress_lock = Lock() if parallel else None - self.validate_content = validate_content - self.lease_access_conditions = access_conditions - self.modified_access_conditions = None - self.append_conditions = append_conditions - self.timeout = timeout + + # Encryption self.encryptor = encryptor self.padder = padder self.response_headers = None @@ -186,8 +155,8 @@ def get_chunk_streams(self): # Buffer until we either reach the end of the stream or get a whole chunk. while True: - if self.blob_size: - read_size = min(self.chunk_size - len(data), self.blob_size - (index + len(data))) + if self.total_size: + read_size = min(self.chunk_size - len(data), self.total_size - (index + len(data))) temp = self.stream.read(read_size) if not isinstance(temp, six.binary_type): raise TypeError('Blob data should be of type bytes.') @@ -237,7 +206,7 @@ def _upload_chunk_with_progress(self, chunk_offset, chunk_data): def get_substream_blocks(self): assert self.chunk_size is not None lock = self.stream_lock - blob_length = self.blob_size + blob_length = self.total_size if blob_length is None: blob_length = get_length(self.stream) @@ -248,9 +217,9 @@ def get_substream_blocks(self): last_block_size = self.chunk_size if blob_length % self.chunk_size == 0 else blob_length % self.chunk_size for i in range(blocks): - yield ('BlockId{}'.format("%05d" % i), - _SubStream(self.stream, i * self.chunk_size, last_block_size if i == blocks - 1 else self.chunk_size, - lock)) + index = i * self.chunk_size + length = last_block_size if i == blocks - 1 else self.chunk_size + yield ('BlockId{}'.format("%05d" % i), SubStream(self.stream, index, length, lock)) def process_substream_block(self, block_data): return self._upload_substream_block_with_progress(block_data[0], block_data[1]) @@ -267,33 +236,27 @@ def set_response_properties(self, resp): self.last_modified = resp.last_modified -class BlockBlobChunkUploader(_BlobChunkUploader): +class BlockBlobChunkUploader(_ChunkUploader): def _upload_chunk(self, chunk_offset, chunk_data): # TODO: This is incorrect, but works with recording. 
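For reference, the double-encoded block ID produced by the line this TODO refers to can be reproduced with the standard library alone. This is a sketch under the assumption that the shared `encode_base64`/`url_quote` helpers are thin wrappers over `base64.b64encode` and `urllib.parse.quote`:

```python
import base64
from urllib.parse import quote

def block_id_for(chunk_offset):
    # Zero-padded offset -> base64 -> URL-quote -> base64 again, matching the
    # chain flagged above as incorrect but recording-compatible.
    inner = base64.b64encode('{0:032d}'.format(chunk_offset).encode()).decode()
    return base64.b64encode(quote(inner).encode()).decode()

print(block_id_for(4 * 1024 * 1024))
```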
block_id = encode_base64(url_quote(encode_base64('{0:032d}'.format(chunk_offset)))) - self.blob_service.stage_block( + self.service.stage_block( block_id, len(chunk_data), chunk_data, - timeout=self.timeout, - lease_access_conditions=self.lease_access_conditions, - validate_content=self.validate_content, - data_stream_total=self.blob_size, + data_stream_total=self.total_size, upload_stream_current=self.progress_total, **self.request_options) return block_id def _upload_substream_block(self, block_id, block_stream): try: - self.blob_service.stage_block( + self.service.stage_block( block_id, len(block_stream), block_stream, - validate_content=self.validate_content, - lease_access_conditions=self.lease_access_conditions, - timeout=self.timeout, - data_stream_total=self.blob_size, + data_stream_total=self.total_size, upload_stream_current=self.progress_total, **self.request_options) finally: @@ -301,7 +264,7 @@ def _upload_substream_block(self, block_id, block_stream): return block_id -class PageBlobChunkUploader(_BlobChunkUploader): # pylint: disable=abstract-method +class PageBlobChunkUploader(_ChunkUploader): # pylint: disable=abstract-method def _is_chunk_empty(self, chunk_data): # read until non-zero byte is encountered @@ -317,26 +280,21 @@ def _upload_chunk(self, chunk_offset, chunk_data): chunk_end = chunk_offset + len(chunk_data) - 1 content_range = 'bytes={0}-{1}'.format(chunk_offset, chunk_end) computed_md5 = None - self.response_headers = self.blob_service.upload_pages( + self.response_headers = self.service.upload_pages( chunk_data, content_length=len(chunk_data), transactional_content_md5=computed_md5, - timeout=self.timeout, range=content_range, - lease_access_conditions=self.lease_access_conditions, - modified_access_conditions=self.modified_access_conditions, - validate_content=self.validate_content, cls=return_response_headers, - data_stream_total=self.blob_size, + data_stream_total=self.total_size, upload_stream_current=self.progress_total, **self.request_options) - if not self.parallel: - self.modified_access_conditions = ModifiedAccessConditions( - if_match=self.response_headers['etag']) + if not self.parallel and self.request_options.get('modified_access_conditions'): + self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] -class AppendBlobChunkUploader(_BlobChunkUploader): # pylint: disable=abstract-method +class AppendBlobChunkUploader(_ChunkUploader): # pylint: disable=abstract-method def __init__(self, *args, **kwargs): super(AppendBlobChunkUploader, self).__init__(*args, **kwargs) @@ -344,126 +302,45 @@ def __init__(self, *args, **kwargs): def _upload_chunk(self, chunk_offset, chunk_data): if self.current_length is None: - self.response_headers = self.blob_service.append_block( + self.response_headers = self.service.append_block( chunk_data, content_length=len(chunk_data), - timeout=self.timeout, - lease_access_conditions=self.lease_access_conditions, - modified_access_conditions=self.modified_access_conditions, - validate_content=self.validate_content, - append_position_access_conditions=self.append_conditions, cls=return_response_headers, - data_stream_total=self.blob_size, + data_stream_total=self.total_size, upload_stream_current=self.progress_total, **self.request_options ) self.current_length = int(self.response_headers['blob_append_offset']) else: - self.append_conditions.append_position = self.current_length + chunk_offset - self.response_headers = self.blob_service.append_block( + 
self.request_options['append_position_access_conditions'].append_position = \ + self.current_length + chunk_offset + self.response_headers = self.service.append_block( chunk_data, content_length=len(chunk_data), - timeout=self.timeout, - lease_access_conditions=self.lease_access_conditions, - modified_access_conditions=self.modified_access_conditions, - validate_content=self.validate_content, - append_position_access_conditions=self.append_conditions, cls=return_response_headers, - data_stream_total=self.blob_size, + data_stream_total=self.total_size, upload_stream_current=self.progress_total, **self.request_options ) -class FileChunkUploader(object): # pylint: disable=too-many-instance-attributes - - def __init__(self, file_service, file_size, chunk_size, stream, parallel, - validate_content, timeout, **kwargs): - self.file_service = file_service - self.file_size = file_size - self.chunk_size = chunk_size - self.stream = stream - self.parallel = parallel - self.stream_start = stream.tell() if parallel else None - self.stream_lock = Lock() if parallel else None - self.progress_total = 0 - self.progress_lock = Lock() if parallel else None - self.validate_content = validate_content - self.timeout = timeout - self.request_options = kwargs - - def get_chunk_offsets(self): - index = 0 - if self.file_size is None: - # we don't know the size of the stream, so we have no - # choice but to seek - while True: - data = self._read_from_stream(index, 1) - if not data: - break - yield index - index += self.chunk_size - else: - while index < self.file_size: - yield index - index += self.chunk_size - - def process_chunk(self, chunk_offset): - size = self.chunk_size - if self.file_size is not None: - size = min(size, self.file_size - chunk_offset) - chunk_data = self._read_from_stream(chunk_offset, size) - return self._upload_chunk_with_progress(chunk_offset, chunk_data) - - def process_all_unknown_size(self): - assert self.stream_lock is None - range_ids = [] - index = 0 - while True: - data = self._read_from_stream(None, self.chunk_size) - if data: - index += len(data) - range_id = self._upload_chunk_with_progress(index, data) - range_ids.append(range_id) - else: - break - - return range_ids - - def _read_from_stream(self, offset, count): - if self.stream_lock is not None: - with self.stream_lock: - self.stream.seek(self.stream_start + offset) - data = self.stream.read(count) - else: - data = self.stream.read(count) - return data +class FileChunkUploader(_ChunkUploader): # pylint: disable=abstract-method - def _update_progress(self, length): - if self.progress_lock is not None: - with self.progress_lock: - self.progress_total += length - else: - self.progress_total += length - - def _upload_chunk_with_progress(self, chunk_start, chunk_data): - chunk_end = chunk_start + len(chunk_data) - 1 - self.file_service.upload_range( + def _upload_chunk(self, chunk_offset, chunk_data): + chunk_end = chunk_offset + len(chunk_data) - 1 + self.service.upload_range( chunk_data, - chunk_start, + chunk_offset, chunk_end, - validate_content=self.validate_content, - timeout=self.timeout, - data_stream_total=self.file_size, + data_stream_total=self.total_size, upload_stream_current=self.progress_total, **self.request_options ) - range_id = 'bytes={0}-{1}'.format(chunk_start, chunk_end) - self._update_progress(len(chunk_data)) - return range_id + return 'bytes={0}-{1}'.format(chunk_offset, chunk_end) + +class SubStream(IOBase): -class _SubStream(IOBase): def __init__(self, wrapped_stream, stream_begin_index, length, lockObj): 
 # Python 2.7: file-like objects created with open() typically support seek(), but are not
 # derivations of io.IOBase and thus do not implement seekable().
@@ -487,7 +364,7 @@ def __init__(self, wrapped_stream, stream_begin_index, length, lockObj):
             else _LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE
         self._current_buffer_start = 0
         self._current_buffer_size = 0
-        super(_SubStream, self).__init__()
+        super(SubStream, self).__init__()
 
     def __len__(self):
         return self._length
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/uploads_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/uploads_async.py
new file mode 100644
index 000000000000..984a6bf6588b
--- /dev/null
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/uploads_async.py
@@ -0,0 +1,342 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+# pylint: disable=no-self-use
+
+import asyncio
+from asyncio import Lock
+from itertools import islice
+
+from math import ceil
+
+import six
+
+from . import encode_base64, url_quote
+from .request_handlers import get_length
+from .response_handlers import return_response_headers
+from .encryption import get_blob_encryptor_and_padder
+from .uploads import SubStream, IterStreamer  # pylint: disable=unused-import
+
+
+_LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE = 4 * 1024 * 1024
+_ERROR_VALUE_SHOULD_BE_SEEKABLE_STREAM = '{0} should be a seekable file-like/io.IOBase type stream object.'
+
+
+async def _parallel_uploads(uploader, pending, running):
+    range_ids = []
+    while True:
+        # Wait for some upload to finish before adding a new one
+        done, running = await asyncio.wait(running, return_when=asyncio.FIRST_COMPLETED)
+        range_ids.extend([chunk.result() for chunk in done])
+        try:
+            next_chunk = next(pending)
+        except StopIteration:
+            break
+        else:
+            running.add(asyncio.ensure_future(uploader.process_chunk(next_chunk)))
+
+    # Wait for the remaining uploads to finish
+    if running:
+        done, _running = await asyncio.wait(running)
+        range_ids.extend([chunk.result() for chunk in done])
+    return range_ids
+
+
+async def upload_data_chunks(
+        service=None,
+        uploader_class=None,
+        total_size=None,
+        chunk_size=None,
+        max_connections=None,
+        stream=None,
+        encryption_options=None,
+        **kwargs):
+
+    if encryption_options:
+        encryptor, padder = get_blob_encryptor_and_padder(
+            encryption_options.get('key'),
+            encryption_options.get('vector'),
+            uploader_class is not PageBlobChunkUploader)
+        kwargs['encryptor'] = encryptor
+        kwargs['padder'] = padder
+
+    parallel = max_connections > 1
+    if parallel and 'modified_access_conditions' in kwargs:
+        # Access conditions do not work with parallelism
+        kwargs['modified_access_conditions'] = None
+
+    uploader = uploader_class(
+        service=service,
+        total_size=total_size,
+        chunk_size=chunk_size,
+        stream=stream,
+        parallel=parallel,
+        **kwargs)
+
+    if parallel:
+        upload_tasks = uploader.get_chunk_streams()
+        running_futures = [
+            asyncio.ensure_future(uploader.process_chunk(u))
+            for u in islice(upload_tasks, 0, max_connections)
+        ]
+        range_ids = await _parallel_uploads(uploader, upload_tasks, running_futures)
+    else:
+        range_ids = []
+        for chunk in uploader.get_chunk_streams():
+            range_ids.append(await uploader.process_chunk(chunk))
+
+    if any(range_ids):
+        return
range_ids + return uploader.response_headers + + +async def upload_substream_blocks( + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_connections=None, + stream=None, + **kwargs): + parallel = max_connections > 1 + if parallel and 'modified_access_conditions' in kwargs: + # Access conditions do not work with parallelism + kwargs['modified_access_conditions'] = None + uploader = uploader_class( + service=service, + total_size=total_size, + chunk_size=chunk_size, + stream=stream, + parallel=parallel, + **kwargs) + + if parallel: + upload_tasks = uploader.get_substream_blocks() + running_futures = [ + asyncio.ensure_future(uploader.process_substream_block(u)) + for u in islice(upload_tasks, 0, max_connections) + ] + return await _parallel_uploads(uploader, upload_tasks, running_futures) + blocks = [] + for block in uploader.get_substream_blocks(): + blocks.append(await uploader.process_substream_block(block)) + return blocks + + +class _ChunkUploader(object): # pylint: disable=too-many-instance-attributes + + def __init__(self, service, total_size, chunk_size, stream, parallel, encryptor=None, padder=None, **kwargs): + self.service = service + self.total_size = total_size + self.chunk_size = chunk_size + self.stream = stream + self.parallel = parallel + + # Stream management + self.stream_start = stream.tell() if parallel else None + self.stream_lock = Lock() if parallel else None + + # Progress feedback + self.progress_total = 0 + self.progress_lock = Lock() if parallel else None + + # Encryption + self.encryptor = encryptor + self.padder = padder + self.response_headers = None + self.etag = None + self.last_modified = None + self.request_options = kwargs + + def get_chunk_streams(self): + index = 0 + while True: + data = b'' + read_size = self.chunk_size + + # Buffer until we either reach the end of the stream or get a whole chunk. + while True: + if self.total_size: + read_size = min(self.chunk_size - len(data), self.total_size - (index + len(data))) + temp = self.stream.read(read_size) + if not isinstance(temp, six.binary_type): + raise TypeError('Blob data should be of type bytes.') + data += temp or b"" + + # We have read an empty string and so are at the end + # of the buffer or we have read a full chunk. 
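The buffering loop above, together with the `temp == b''` termination check that follows, can be isolated into a small generator. A simplified sketch that yields `(offset, bytes)` pairs the way `get_chunk_streams` does, minus encryption and padding:

```python
import io

def iter_chunks(stream, chunk_size):
    # Accumulate reads until a full chunk or end of stream, as above.
    index = 0
    while True:
        data = b""
        while True:
            temp = stream.read(chunk_size - len(data))
            data += temp or b""
            if temp == b"" or len(data) == chunk_size:
                break
        if not data:
            return
        yield index, data
        if len(data) < chunk_size:
            return
        index += len(data)

print(list(iter_chunks(io.BytesIO(b"0123456789"), 4)))
# [(0, b'0123'), (4, b'4567'), (8, b'89')]
```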
+                if temp == b'' or len(data) == self.chunk_size:
+                    break
+
+            if len(data) == self.chunk_size:
+                if self.padder:
+                    data = self.padder.update(data)
+                if self.encryptor:
+                    data = self.encryptor.update(data)
+                yield index, data
+            else:
+                if self.padder:
+                    data = self.padder.update(data) + self.padder.finalize()
+                if self.encryptor:
+                    data = self.encryptor.update(data) + self.encryptor.finalize()
+                if data:
+                    yield index, data
+                break
+            index += len(data)
+
+    async def process_chunk(self, chunk_data):
+        chunk_bytes = chunk_data[1]
+        chunk_offset = chunk_data[0]
+        return await self._upload_chunk_with_progress(chunk_offset, chunk_bytes)
+
+    async def _update_progress(self, length):
+        if self.progress_lock is not None:
+            async with self.progress_lock:
+                self.progress_total += length
+        else:
+            self.progress_total += length
+
+    async def _upload_chunk(self, chunk_offset, chunk_data):
+        raise NotImplementedError("Must be implemented by child class.")
+
+    async def _upload_chunk_with_progress(self, chunk_offset, chunk_data):
+        range_id = await self._upload_chunk(chunk_offset, chunk_data)
+        await self._update_progress(len(chunk_data))
+        return range_id
+
+    def get_substream_blocks(self):
+        assert self.chunk_size is not None
+        lock = self.stream_lock
+        blob_length = self.total_size
+
+        if blob_length is None:
+            blob_length = get_length(self.stream)
+            if blob_length is None:
+                raise ValueError("Unable to determine content length of upload data.")
+
+        blocks = int(ceil(blob_length / (self.chunk_size * 1.0)))
+        last_block_size = self.chunk_size if blob_length % self.chunk_size == 0 else blob_length % self.chunk_size
+
+        for i in range(blocks):
+            index = i * self.chunk_size
+            length = last_block_size if i == blocks - 1 else self.chunk_size
+            yield ('BlockId{}'.format("%05d" % i), SubStream(self.stream, index, length, lock))
+
+    async def process_substream_block(self, block_data):
+        return await self._upload_substream_block_with_progress(block_data[0], block_data[1])
+
+    async def _upload_substream_block(self, block_id, block_stream):
+        raise NotImplementedError("Must be implemented by child class.")
+
+    async def _upload_substream_block_with_progress(self, block_id, block_stream):
+        range_id = await self._upload_substream_block(block_id, block_stream)
+        await self._update_progress(len(block_stream))
+        return range_id
+
+    def set_response_properties(self, resp):
+        self.etag = resp.etag
+        self.last_modified = resp.last_modified
+
+
+class BlockBlobChunkUploader(_ChunkUploader):
+
+    async def _upload_chunk(self, chunk_offset, chunk_data):
+        # TODO: This is incorrect, but works with recording.
+ block_id = encode_base64(url_quote(encode_base64('{0:032d}'.format(chunk_offset)))) + await self.service.stage_block( + block_id, + len(chunk_data), + chunk_data, + data_stream_total=self.total_size, + upload_stream_current=self.progress_total, + **self.request_options) + return block_id + + async def _upload_substream_block(self, block_id, block_stream): + try: + await self.service.stage_block( + block_id, + len(block_stream), + block_stream, + data_stream_total=self.total_size, + upload_stream_current=self.progress_total, + **self.request_options) + finally: + block_stream.close() + return block_id + + +class PageBlobChunkUploader(_ChunkUploader): # pylint: disable=abstract-method + + def _is_chunk_empty(self, chunk_data): + # read until non-zero byte is encountered + # if reached the end without returning, then chunk_data is all 0's + for each_byte in chunk_data: + if each_byte not in [0, b'\x00']: + return False + return True + + async def _upload_chunk(self, chunk_offset, chunk_data): + # avoid uploading the empty pages + if not self._is_chunk_empty(chunk_data): + chunk_end = chunk_offset + len(chunk_data) - 1 + content_range = 'bytes={0}-{1}'.format(chunk_offset, chunk_end) + computed_md5 = None + self.response_headers = await self.service.upload_pages( + chunk_data, + content_length=len(chunk_data), + transactional_content_md5=computed_md5, + range=content_range, + cls=return_response_headers, + data_stream_total=self.total_size, + upload_stream_current=self.progress_total, + **self.request_options) + + if not self.parallel and self.request_options.get('modified_access_conditions'): + self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] + + +class AppendBlobChunkUploader(_ChunkUploader): # pylint: disable=abstract-method + + def __init__(self, *args, **kwargs): + super(AppendBlobChunkUploader, self).__init__(*args, **kwargs) + self.current_length = None + + async def _upload_chunk(self, chunk_offset, chunk_data): + if self.current_length is None: + self.response_headers = await self.service.append_block( + chunk_data, + content_length=len(chunk_data), + cls=return_response_headers, + data_stream_total=self.total_size, + upload_stream_current=self.progress_total, + **self.request_options) + self.current_length = int(self.response_headers['blob_append_offset']) + else: + self.request_options['append_position_access_conditions'].append_position = \ + self.current_length + chunk_offset + self.response_headers = await self.service.append_block( + chunk_data, + content_length=len(chunk_data), + cls=return_response_headers, + data_stream_total=self.total_size, + upload_stream_current=self.progress_total, + **self.request_options) + + +class FileChunkUploader(_ChunkUploader): # pylint: disable=abstract-method + + async def _upload_chunk(self, chunk_offset, chunk_data): + chunk_end = chunk_offset + len(chunk_data) - 1 + await self.service.upload_range( + chunk_data, + chunk_offset, + chunk_end, + data_stream_total=self.total_size, + upload_stream_current=self.progress_total, + **self.request_options + ) + range_id = 'bytes={0}-{1}'.format(chunk_offset, chunk_end) + return range_id diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/__init__.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/__init__.py new file mode 100644 index 000000000000..c70ebec1f8cc --- /dev/null +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/__init__.py @@ -0,0 +1,30 @@ +# 
-------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+from azure.storage.queue.version import VERSION
+from .queue_client_async import QueueClient
+from .queue_service_client_async import QueueServiceClient
+from .models import MessagesPaged, QueuePropertiesPaged
+from ..models import (
+    Logging, Metrics, RetentionPolicy, CorsRule, AccessPolicy,
+    QueueMessage, QueuePermissions, QueueProperties)
+
+__version__ = VERSION
+
+__all__ = [
+    'QueueClient',
+    'QueueServiceClient',
+    'Logging',
+    'Metrics',
+    'RetentionPolicy',
+    'CorsRule',
+    'AccessPolicy',
+    'QueueMessage',
+    'MessagesPaged',
+    'QueuePermissions',
+    'QueueProperties',
+    'QueuePropertiesPaged'
+]
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/models.py
new file mode 100644
index 000000000000..047f34204085
--- /dev/null
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/models.py
@@ -0,0 +1,108 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+# pylint: disable=too-few-public-methods, too-many-instance-attributes
+# pylint: disable=super-init-not-called
+
+from typing import List  # pylint: disable=unused-import
+from azure.core.paging import Paged
+from .._shared.response_handlers import (
+    process_storage_error,
+    return_context_and_deserialized)
+from .._generated.models import StorageErrorException
+from ..models import QueueMessage, QueueProperties
+
+
+class MessagesPaged(Paged):
+    """An iterable of Queue Messages.
+
+    :ivar int results_per_page: The maximum number of results retrieved per API call.
+    :ivar current_page: The current page of listed results.
+    :vartype current_page: list(~azure.storage.queue.models.QueueMessage)
+
+    :param callable command: Function to retrieve the next page of items.
+    :param int results_per_page: The maximum number of messages to retrieve per
+        call.
+    """
+    def __init__(self, command, results_per_page=None):
+        super(MessagesPaged, self).__init__(None, None, async_command=command)
+        self.results_per_page = results_per_page
+
+    async def _async_advance_page(self):
+        """Force moving the cursor to the next azure call.
+
+        This method is for advanced usage; the iterator protocol is preferred.
+
+        :raises: StopAsyncIteration if no further page
+        :return: The current page list
+        :rtype: list
+        """
+        self._current_page_iter_index = 0
+        try:
+            messages = await self._async_get_next(number_of_messages=self.results_per_page)
+            if not messages:
+                raise StopAsyncIteration()
+        except StorageErrorException as error:
+            process_storage_error(error)
+        self.current_page = [QueueMessage._from_generated(q) for q in messages]  # pylint: disable=protected-access
+        return self.current_page
+
+
+class QueuePropertiesPaged(Paged):
+    """An iterable of Queue properties.
+
+    :ivar str service_endpoint: The service URL.
+    :ivar str prefix: A queue name prefix being used to filter the list.
+    :ivar str current_marker: The continuation token of the current page of results.
+    :ivar int results_per_page: The maximum number of results retrieved per API call.
+    :ivar str next_marker: The continuation token to retrieve the next page of results.
+    :ivar str location_mode: The location mode being used to list results. The available
+        options include "primary" and "secondary".
+    :ivar current_page: The current page of listed results.
+    :vartype current_page: list(~azure.storage.queue.models.QueueProperties)
+
+    :param callable command: Function to retrieve the next page of items.
+    :param str prefix: Filters the results to return only queues whose names
+        begin with the specified prefix.
+    :param int results_per_page: The maximum number of queue names to retrieve per
+        call.
+    :param str marker: An opaque continuation token.
+    """
+    def __init__(self, command, prefix=None, results_per_page=None, marker=None):
+        super(QueuePropertiesPaged, self).__init__(None, None, async_command=command)
+        self.service_endpoint = None
+        self.prefix = prefix
+        self.current_marker = None
+        self.results_per_page = results_per_page
+        self.next_marker = marker or ""
+        self.location_mode = None
+
+    async def _async_advance_page(self):
+        """Force moving the cursor to the next azure call.
+
+        This method is for advanced usage; the iterator protocol is preferred.
+
+        :raises: StopAsyncIteration if no further page
+        :return: The current page list
+        :rtype: list
+        """
+        if self.next_marker is None:
+            raise StopAsyncIteration("End of paging")
+        self._current_page_iter_index = 0
+        try:
+            self.location_mode, self._response = await self._async_get_next(
+                marker=self.next_marker or None,
+                maxresults=self.results_per_page,
+                cls=return_context_and_deserialized,
+                use_location=self.location_mode)
+        except StorageErrorException as error:
+            process_storage_error(error)
+        self.service_endpoint = self._response.service_endpoint
+        self.prefix = self._response.prefix
+        self.current_marker = self._response.marker
+        self.results_per_page = self._response.max_results
+        self.current_page = [QueueProperties._from_generated(q) for q in self._response.queue_items]  # pylint: disable=protected-access
+        self.next_marker = self._response.next_marker or None
+        return self.current_page
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/queue_client_async.py
new file mode 100644
index 000000000000..d1757d12bb1f
--- /dev/null
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/queue_client_async.py
@@ -0,0 +1,644 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import functools
+from typing import (  # pylint: disable=unused-import
+    Union, Optional, Any, IO, Iterable, AnyStr, Dict, List, Tuple,
+    TYPE_CHECKING)
+try:
+    from urllib.parse import urlparse, quote, unquote  # pylint: disable=unused-import
+except ImportError:
+    from urlparse import urlparse  # type: ignore
+    from urllib2 import quote, unquote  # type: ignore
+
+from azure.storage.queue._shared.base_client_async import AsyncStorageAccountHostsMixin
+from azure.storage.queue._shared.request_handlers import add_metadata_headers, serialize_iso
+from azure.storage.queue._shared.response_handlers import (
+    return_response_headers,
+    process_storage_error,
+    return_headers_and_deserialized)
+from azure.storage.queue._deserialize import (
+    deserialize_queue_properties,
+    deserialize_queue_creation)
+from azure.storage.queue._generated.aio import AzureQueueStorage
+from azure.storage.queue._generated.models import StorageErrorException, SignedIdentifier
+from azure.storage.queue._generated.models import QueueMessage as GenQueueMessage
+
+from azure.storage.queue.models import QueueMessage, AccessPolicy
+from azure.storage.queue.aio.models import MessagesPaged
+from .._shared.policies_async import ExponentialRetry
+from ..queue_client import QueueClient as QueueClientBase
+
+
+if TYPE_CHECKING:
+    from datetime import datetime
+    from azure.core.pipeline.policies import HTTPPolicy
+    from azure.storage.queue.models import QueuePermissions, QueueProperties
+
+
+class QueueClient(AsyncStorageAccountHostsMixin, QueueClientBase):
+    """An async client to interact with a specific Queue.
+
+    :ivar str url:
+        The full endpoint URL to the Queue, including SAS token if used. This could be
+        either the primary endpoint, or the secondary endpoint depending on the current `location_mode`.
+    :ivar str primary_endpoint:
+        The full primary endpoint URL.
+    :ivar str primary_hostname:
+        The hostname of the primary endpoint.
+    :ivar str secondary_endpoint:
+        The full secondary endpoint URL if configured. If not available
+        a ValueError will be raised. To explicitly specify a secondary hostname, use the optional
+        `secondary_hostname` keyword argument on instantiation.
+    :ivar str secondary_hostname:
+        The hostname of the secondary endpoint. If not available this
+        will be None. To explicitly specify a secondary hostname, use the optional
+        `secondary_hostname` keyword argument on instantiation.
+    :ivar str location_mode:
+        The location mode that the client is currently using. By default
+        this will be "primary". Options include "primary" and "secondary".
+    :param str queue_url: The full URI to the queue. This can also be a URL to the storage
+        account, in which case the queue must also be specified.
+    :param queue: The queue. If specified, this value will override
+        a queue value specified in the queue URL.
+    :type queue: str or ~azure.storage.queue.models.QueueProperties
+    :param credential:
+        The credentials with which to authenticate. This is optional if the
+        account URL already has a SAS token. The value can be a SAS token string, an account
+        shared access key, or an instance of a TokenCredentials class from azure.identity.
+
+    Example:
+        .. literalinclude:: ../tests/test_queue_samples_message.py
+            :start-after: [START create_queue_client]
+            :end-before: [END create_queue_client]
+            :language: python
+            :dedent: 12
+            :caption: Create the queue client with url and credential.
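A hedged usage sketch of the constructor documented above; the account URL, queue name, and SAS token are placeholders, not values from this change:

```python
import asyncio
from azure.storage.queue.aio import QueueClient

async def main():
    client = QueueClient(
        "https://<myaccount>.queue.core.windows.net",  # placeholder account URL
        queue="tasks",
        credential="<sas_token>")                      # placeholder credential
    await client.create_queue()
    await client.enqueue_message(u"hello world")

# asyncio.get_event_loop().run_until_complete(main())
```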
+ """ + def __init__( + self, queue_url, # type: str + queue=None, # type: Optional[Union[QueueProperties, str]] + credential=None, # type: Optional[Any] + loop=None, # type: Any + **kwargs # type: Any + ): + # type: (...) -> None + kwargs['retry_policy'] = kwargs.get('retry_policy') or ExponentialRetry(**kwargs) + super(QueueClient, self).__init__( + queue_url, + queue=queue, + credential=credential, + loop=loop, + **kwargs) + self._client = AzureQueueStorage(self.url, pipeline=self._pipeline, loop=loop) # type: ignore + self._loop = loop + + async def create_queue(self, metadata=None, timeout=None, **kwargs): # type: ignore + # type: (Optional[Dict[str, Any]], Optional[int], Optional[Any]) -> None + """Creates a new queue in the storage account. + + If a queue with the same name already exists, the operation fails. + + :param metadata: + A dict containing name-value pairs to associate with the queue as + metadata. Note that metadata names preserve the case with which they + were created, but are case-insensitive when set or read. + :type metadata: dict(str, str) + :param int timeout: + The server timeout, expressed in seconds. + :return: None or the result of cls(response) + :rtype: None + :raises: + ~azure.storage.queue._generated.models._models.StorageErrorException + + Example: + .. literalinclude:: ../tests/test_queue_samples_hello_world.py + :start-after: [START create_queue] + :end-before: [END create_queue] + :language: python + :dedent: 8 + :caption: Create a queue. + """ + headers = kwargs.pop('headers', {}) + headers.update(add_metadata_headers(metadata)) # type: ignore + try: + return await self._client.queue.create( # type: ignore + metadata=metadata, + timeout=timeout, + headers=headers, + cls=deserialize_queue_creation, + **kwargs) + except StorageErrorException as error: + process_storage_error(error) + + async def delete_queue(self, timeout=None, **kwargs): # type: ignore + # type: (Optional[int], Optional[Any]) -> None + """Deletes the specified queue and any messages it contains. + + When a queue is successfully deleted, it is immediately marked for deletion + and is no longer accessible to clients. The queue is later removed from + the Queue service during garbage collection. + + Note that deleting a queue is likely to take at least 40 seconds to complete. + If an operation is attempted against the queue while it was being deleted, + an :class:`HttpResponseError` will be thrown. + + :param int timeout: + The server timeout, expressed in seconds. + :rtype: None + + Example: + .. literalinclude:: ../tests/test_queue_samples_hello_world.py + :start-after: [START delete_queue] + :end-before: [END delete_queue] + :language: python + :dedent: 12 + :caption: Delete a queue. + """ + try: + await self._client.queue.delete(timeout=timeout, **kwargs) + except StorageErrorException as error: + process_storage_error(error) + + async def get_queue_properties(self, timeout=None, **kwargs): # type: ignore + # type: (Optional[int], Optional[Any]) -> QueueProperties + """Returns all user-defined metadata for the specified queue. + + The data returned does not include the queue's list of messages. + + :param int timeout: + The timeout parameter is expressed in seconds. + :return: Properties for the specified container within a container object. + :rtype: ~azure.storage.queue.models.QueueProperties + + Example: + .. 
+            .. literalinclude:: ../tests/test_queue_samples_message.py
+                :start-after: [START get_queue_properties]
+                :end-before: [END get_queue_properties]
+                :language: python
+                :dedent: 12
+                :caption: Get the properties on the queue.
+        """
+        try:
+            response = await self._client.queue.get_properties(
+                timeout=timeout,
+                cls=deserialize_queue_properties,
+                **kwargs)
+        except StorageErrorException as error:
+            process_storage_error(error)
+        response.name = self.queue_name
+        return response  # type: ignore
+
+    async def set_queue_metadata(self, metadata=None, timeout=None, **kwargs):  # type: ignore
+        # type: (Optional[Dict[str, Any]], Optional[int], Optional[Any]) -> None
+        """Sets user-defined metadata on the specified queue.
+
+        Metadata is associated with the queue as name-value pairs.
+
+        :param metadata:
+            A dict containing name-value pairs to associate with the
+            queue as metadata.
+        :type metadata: dict(str, str)
+        :param int timeout:
+            The server timeout, expressed in seconds.
+
+        Example:
+            .. literalinclude:: ../tests/test_queue_samples_message.py
+                :start-after: [START set_queue_metadata]
+                :end-before: [END set_queue_metadata]
+                :language: python
+                :dedent: 12
+                :caption: Set metadata on the queue.
+        """
+        headers = kwargs.pop('headers', {})
+        headers.update(add_metadata_headers(metadata))  # type: ignore
+        try:
+            return (await self._client.queue.set_metadata(  # type: ignore
+                timeout=timeout,
+                headers=headers,
+                cls=return_response_headers,
+                **kwargs))
+        except StorageErrorException as error:
+            process_storage_error(error)
+
+    async def get_queue_access_policy(self, timeout=None, **kwargs):  # type: ignore
+        # type: (Optional[int], Optional[Any]) -> Dict[str, Any]
+        """Returns details about any stored access policies specified on the
+        queue that may be used with Shared Access Signatures.
+
+        :param int timeout:
+            The server timeout, expressed in seconds.
+        :return: A dictionary of access policies associated with the queue.
+        :rtype: dict(str, :class:`~azure.storage.queue.models.AccessPolicy`)
+        """
+        try:
+            _, identifiers = await self._client.queue.get_access_policy(
+                timeout=timeout,
+                cls=return_headers_and_deserialized,
+                **kwargs)
+        except StorageErrorException as error:
+            process_storage_error(error)
+        return {s.id: s.access_policy or AccessPolicy() for s in identifiers}
+
+    async def set_queue_access_policy(self, signed_identifiers=None, timeout=None, **kwargs):  # type: ignore
+        # type: (Optional[Dict[str, Optional[AccessPolicy]]], Optional[int], Optional[Any]) -> None
+        """Sets stored access policies for the queue that may be used with Shared
+        Access Signatures.
+
+        When you set permissions for a queue, the existing permissions are replaced.
+        To update the queue's permissions, call :func:`~get_queue_access_policy` to fetch
+        all access policies associated with the queue, modify the access policy
+        that you wish to change, and then call this function with the complete
+        set of data to perform the update.
+
+        When you establish a stored access policy on a queue, it may take up to
+        30 seconds to take effect. During this interval, a shared access signature
+        that is associated with the stored access policy will throw an
+        :class:`HttpResponseError` until the access policy becomes active.
+
+        :param signed_identifiers:
+            A dict of SignedIdentifier access policies to associate with the queue.
+            The dict may contain up to 5 elements. An empty dict
+            will clear the access policies set on the queue.
+        :type signed_identifiers: dict(str, :class:`~azure.storage.queue.models.AccessPolicy`)
+        :param int timeout:
+            The server timeout, expressed in seconds.
+
+        Example:
+            .. literalinclude:: ../tests/test_queue_samples_message.py
+                :start-after: [START set_access_policy]
+                :end-before: [END set_access_policy]
+                :language: python
+                :dedent: 12
+                :caption: Set an access policy on the queue.
+        """
+        if signed_identifiers:
+            if len(signed_identifiers) > 5:
+                raise ValueError(
+                    'Too many access policies provided. The server does not support setting '
+                    'more than 5 access policies on a single resource.')
+            identifiers = []
+            for key, value in signed_identifiers.items():
+                if value:
+                    value.start = serialize_iso(value.start)
+                    value.expiry = serialize_iso(value.expiry)
+                identifiers.append(SignedIdentifier(id=key, access_policy=value))
+            signed_identifiers = identifiers  # type: ignore
+        try:
+            await self._client.queue.set_access_policy(
+                queue_acl=signed_identifiers or None,
+                timeout=timeout,
+                **kwargs)
+        except StorageErrorException as error:
+            process_storage_error(error)
+
+    async def enqueue_message(  # type: ignore
+            self, content,  # type: Any
+            visibility_timeout=None,  # type: Optional[int]
+            time_to_live=None,  # type: Optional[int]
+            timeout=None,  # type: Optional[int]
+            **kwargs  # type: Optional[Any]
+        ):
+        # type: (...) -> QueueMessage
+        """Adds a new message to the back of the message queue.
+
+        The visibility timeout specifies the time that the message will be
+        invisible. After the timeout expires, the message will become visible.
+        If a visibility timeout is not specified, the default value of 0 is used.
+
+        The message time-to-live specifies how long a message will remain in the
+        queue. The message will be deleted from the queue when the time-to-live
+        period expires.
+
+        If the key-encryption-key field is set on the local service object, this method will
+        encrypt the content before uploading.
+
+        :param obj content:
+            Message content. Allowed type is determined by the encode_function
+            set on the service. Default is str. The encoded message can be up to
+            64KB in size.
+        :param int visibility_timeout:
+            If not specified, the default value is 0. Specifies the
+            new visibility timeout value, in seconds, relative to server time.
+            The value must be larger than or equal to 0, and cannot be
+            larger than 7 days. The visibility timeout of a message cannot be
+            set to a value later than the expiry time. visibility_timeout
+            should be set to a value smaller than the time-to-live value.
+        :param int time_to_live:
+            Specifies the time-to-live interval for the message, in
+            seconds. The time-to-live may be any positive number or -1 for infinity. If this
+            parameter is omitted, the default time-to-live is 7 days.
+        :param int timeout:
+            The server timeout, expressed in seconds.
+        :return:
+            A :class:`~azure.storage.queue.models.QueueMessage` object.
+            This object is also populated with the content although it is not
+            returned from the service.
+        :rtype: ~azure.storage.queue.models.QueueMessage
+
+        Example:
+            .. literalinclude:: ../tests/test_queue_samples_message.py
+                :start-after: [START enqueue_messages]
+                :end-before: [END enqueue_messages]
+                :language: python
+                :dedent: 12
+                :caption: Enqueue messages.
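Building on the parameters described above, a small sketch of a delayed, expiring message; client construction is omitted and `queue_client` is assumed to be a connected `QueueClient`:

```python
async def enqueue_delayed(queue_client):
    # Invisible for the first 60 seconds, deleted after one hour if unprocessed.
    message = await queue_client.enqueue_message(
        u"process-me", visibility_timeout=60, time_to_live=3600)
    return message.id, message.pop_receipt
```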
+ """ + self._config.message_encode_policy.configure( + self.require_encryption, + self.key_encryption_key, + self.key_resolver_function) + content = self._config.message_encode_policy(content) + new_message = GenQueueMessage(message_text=content) + + try: + enqueued = await self._client.messages.enqueue( + queue_message=new_message, + visibilitytimeout=visibility_timeout, + message_time_to_live=time_to_live, + timeout=timeout, + **kwargs) + queue_message = QueueMessage(content=new_message.message_text) + queue_message.id = enqueued[0].message_id + queue_message.insertion_time = enqueued[0].insertion_time + queue_message.expiration_time = enqueued[0].expiration_time + queue_message.pop_receipt = enqueued[0].pop_receipt + queue_message.time_next_visible = enqueued[0].time_next_visible + return queue_message + except StorageErrorException as error: + process_storage_error(error) + + def receive_messages(self, messages_per_page=None, visibility_timeout=None, timeout=None, **kwargs): # type: ignore + # type: (Optional[int], Optional[int], Optional[int], Optional[Any]) -> QueueMessage + """Removes one or more messages from the front of the queue. + + When a message is retrieved from the queue, the response includes the message + content and a pop_receipt value, which is required to delete the message. + The message is not automatically deleted from the queue, but after it has + been retrieved, it is not visible to other clients for the time interval + specified by the visibility_timeout parameter. + + If the key-encryption-key or resolver field is set on the local service object, the messages will be + decrypted before being returned. + + :param int messages_per_page: + A nonzero integer value that specifies the number of + messages to retrieve from the queue, up to a maximum of 32. If + fewer are visible, the visible messages are returned. By default, + a single message is retrieved from the queue with this operation. + :param int visibility_timeout: + If not specified, the default value is 0. Specifies the + new visibility timeout value, in seconds, relative to server time. + The value must be larger than or equal to 0, and cannot be + larger than 7 days. The visibility timeout of a message cannot be + set to a value later than the expiry time. visibility_timeout + should be set to a value smaller than the time-to-live value. + :param int timeout: + The server timeout, expressed in seconds. + :return: + Returns a message iterator of dict-like Message objects. + :rtype: ~azure.storage.queue.models.MessagesPaged + + Example: + .. literalinclude:: ../tests/test_queue_samples_message.py + :start-after: [START receive_messages] + :end-before: [END receive_messages] + :language: python + :dedent: 12 + :caption: Receive messages from the queue. + """ + self._config.message_decode_policy.configure( + self.require_encryption, + self.key_encryption_key, + self.key_resolver_function) + try: + command = functools.partial( + self._client.messages.dequeue, + visibilitytimeout=visibility_timeout, + timeout=timeout, + cls=self._config.message_decode_policy, + **kwargs + ) + return MessagesPaged(command, results_per_page=messages_per_page) + except StorageErrorException as error: + process_storage_error(error) + + async def update_message(self, message, visibility_timeout=None, pop_receipt=None, # type: ignore + content=None, timeout=None, **kwargs): + # type: (Any, int, Optional[str], Optional[Any], Optional[int], Any) -> QueueMessage + """Updates the visibility timeout of a message. 
+        operation to update the contents of a message.
+
+        This operation can be used to continually extend the invisibility of a
+        queue message. This functionality can be useful if you want a worker role
+        to "lease" a queue message. For example, if a worker role calls :func:`~receive_messages()`
+        and recognizes that it needs more time to process a message, it can
+        continually extend the message's invisibility until it is processed. If
+        the worker role were to fail during processing, eventually the message
+        would become visible again and another worker role could process it.
+
+        If the key-encryption-key field is set on the local service object, this method will
+        encrypt the content before uploading.
+
+        :param message:
+            The message object or id identifying the message to update.
+        :param int visibility_timeout:
+            Specifies the new visibility timeout value, in seconds,
+            relative to server time. The new value must be larger than or equal
+            to 0, and cannot be larger than 7 days. The visibility timeout of a
+            message cannot be set to a value later than the expiry time. A
+            message can be updated until it has been deleted or has expired.
+        :param str pop_receipt:
+            A valid pop receipt value returned from an earlier call
+            to the :func:`~receive_messages` or :func:`~update_message` operation.
+        :param obj content:
+            Message content. Allowed type is determined by the encode_function
+            set on the service. Default is str.
+        :param int timeout:
+            The server timeout, expressed in seconds.
+        :return:
+            A :class:`~azure.storage.queue.models.QueueMessage` object. For convenience,
+            this object is also populated with the content, although it is not returned by the service.
+        :rtype: ~azure.storage.queue.models.QueueMessage
+
+        Example:
+            .. literalinclude:: ../tests/test_queue_samples_message.py
+                :start-after: [START update_message]
+                :end-before: [END update_message]
+                :language: python
+                :dedent: 12
+                :caption: Update a message.
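The "lease" pattern described above amounts to repeatedly pushing the invisibility window out while work is in progress. A sketch, assuming `queue_client` and a `message` previously returned by `receive_messages`:

```python
async def extend_lease(queue_client, message, seconds=30):
    # Push the message's visibility timeout out by another `seconds`,
    # keeping it hidden from other workers while processing continues.
    return await queue_client.update_message(
        message, visibility_timeout=seconds, pop_receipt=message.pop_receipt)
```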
+ """ + try: + message_id = message.id + message_text = content or message.content + receipt = pop_receipt or message.pop_receipt + insertion_time = message.insertion_time + expiration_time = message.expiration_time + dequeue_count = message.dequeue_count + except AttributeError: + message_id = message + message_text = content + receipt = pop_receipt + insertion_time = None + expiration_time = None + dequeue_count = None + + if receipt is None: + raise ValueError("pop_receipt must be present") + if message_text is not None: + self._config.message_encode_policy.configure( + self.require_encryption, + self.key_encryption_key, + self.key_resolver_function) + message_text = self._config.message_encode_policy(message_text) + updated = GenQueueMessage(message_text=message_text) + else: + updated = None # type: ignore + try: + response = await self._client.message_id.update( + queue_message=updated, + visibilitytimeout=visibility_timeout or 0, + timeout=timeout, + pop_receipt=receipt, + cls=return_response_headers, + queue_message_id=message_id, + **kwargs) + new_message = QueueMessage(content=message_text) + new_message.id = message_id + new_message.insertion_time = insertion_time + new_message.expiration_time = expiration_time + new_message.dequeue_count = dequeue_count + new_message.pop_receipt = response['popreceipt'] + new_message.time_next_visible = response['time_next_visible'] + return new_message + except StorageErrorException as error: + process_storage_error(error) + + async def peek_messages(self, max_messages=None, timeout=None, **kwargs): # type: ignore + # type: (Optional[int], Optional[int], Optional[Any]) -> List[QueueMessage] + """Retrieves one or more messages from the front of the queue, but does + not alter the visibility of the message. + + Only messages that are visible may be retrieved. When a message is retrieved + for the first time with a call to :func:`~receive_messages`, its dequeue_count property + is set to 1. If it is not deleted and is subsequently retrieved again, the + dequeue_count property is incremented. The client may use this value to + determine how many times a message has been retrieved. Note that a call + to peek_messages does not increment the value of DequeueCount, but returns + this value for the client to read. + + If the key-encryption-key or resolver field is set on the local service object, + the messages will be decrypted before being returned. + + :param int max_messages: + A nonzero integer value that specifies the number of + messages to peek from the queue, up to a maximum of 32. By default, + a single message is peeked from the queue with this operation. + :param int timeout: + The server timeout, expressed in seconds. + :return: + A list of :class:`~azure.storage.queue.models.QueueMessage` objects. Note that + time_next_visible and pop_receipt will not be populated as peek does + not pop the message and can only retrieve already visible messages. + :rtype: list(:class:`~azure.storage.queue.models.QueueMessage`) + + Example: + .. literalinclude:: ../tests/test_queue_samples_message.py + :start-after: [START peek_message] + :end-before: [END peek_message] + :language: python + :dedent: 12 + :caption: Peek messages. 
+ """ + if max_messages and not 1 <= max_messages <= 32: + raise ValueError("Number of messages to peek should be between 1 and 32") + self._config.message_decode_policy.configure( + self.require_encryption, + self.key_encryption_key, + self.key_resolver_function) + try: + messages = await self._client.messages.peek( + number_of_messages=max_messages, + timeout=timeout, + cls=self._config.message_decode_policy, + **kwargs) + wrapped_messages = [] + for peeked in messages: + wrapped_messages.append(QueueMessage._from_generated(peeked)) # pylint: disable=protected-access + return wrapped_messages + except StorageErrorException as error: + process_storage_error(error) + + async def clear_messages(self, timeout=None, **kwargs): # type: ignore + # type: (Optional[int], Optional[Any]) -> None + """Deletes all messages from the specified queue. + + :param int timeout: + The server timeout, expressed in seconds. + + Example: + .. literalinclude:: ../tests/test_queue_samples_message.py + :start-after: [START clear_messages] + :end-before: [END clear_messages] + :language: python + :dedent: 12 + :caption: Clears all messages. + """ + try: + await self._client.messages.clear(timeout=timeout, **kwargs) + except StorageErrorException as error: + process_storage_error(error) + + async def delete_message(self, message, pop_receipt=None, timeout=None, **kwargs): # type: ignore + # type: (Any, Optional[str], Optional[str], Optional[int]) -> None + """Deletes the specified message. + + Normally after a client retrieves a message with the receive messages operation, + the client is expected to process and delete the message. To delete the + message, you must have the message object itself, or two items of data: id and pop_receipt. + The id is returned from the previous receive_messages operation. The + pop_receipt is returned from the most recent :func:`~receive_messages` or + :func:`~update_message` operation. In order for the delete_message operation + to succeed, the pop_receipt specified on the request must match the + pop_receipt returned from the :func:`~receive_messages` or :func:`~update_message` + operation. + + :param str message: + The message object or id identifying the message to delete. + :param str pop_receipt: + A valid pop receipt value returned from an earlier call + to the :func:`~receive_messages` or :func:`~update_message`. + :param int timeout: + The server timeout, expressed in seconds. + + Example: + .. literalinclude:: ../tests/test_queue_samples_message.py + :start-after: [START delete_message] + :end-before: [END delete_message] + :language: python + :dedent: 12 + :caption: Delete a message. 
+ """ + try: + message_id = message.id + receipt = pop_receipt or message.pop_receipt + except AttributeError: + message_id = message + receipt = pop_receipt + + if receipt is None: + raise ValueError("pop_receipt must be present") + try: + await self._client.message_id.delete( + pop_receipt=receipt, + timeout=timeout, + queue_message_id=message_id, + **kwargs + ) + except StorageErrorException as error: + process_storage_error(error) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/queue_service_client_async.py new file mode 100644 index 000000000000..47d767069933 --- /dev/null +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/queue_service_client_async.py @@ -0,0 +1,350 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import functools +from typing import ( # pylint: disable=unused-import + Union, Optional, Any, Iterable, Dict, List, + TYPE_CHECKING) +try: + from urllib.parse import urlparse # pylint: disable=unused-import +except ImportError: + from urlparse import urlparse # type: ignore + +from azure.storage.queue._shared.policies_async import ExponentialRetry +from azure.storage.queue.queue_service_client import QueueServiceClient as QueueServiceClientBase +from azure.storage.queue._shared.models import LocationMode +from azure.storage.queue._shared.base_client_async import AsyncStorageAccountHostsMixin +from azure.storage.queue._shared.response_handlers import process_storage_error +from azure.storage.queue._generated.aio import AzureQueueStorage +from azure.storage.queue._generated.models import StorageServiceProperties, StorageErrorException + +from azure.storage.queue.aio.models import QueuePropertiesPaged +from .queue_client_async import QueueClient + +if TYPE_CHECKING: + from datetime import datetime + from azure.core import Configuration + from azure.core.pipeline.policies import HTTPPolicy + from azure.storage.queue._shared.models import AccountPermissions, ResourceTypes + from azure.storage.queue.aio.models import ( + QueueProperties + ) + from azure.storage.queue.models import Logging, Metrics, CorsRule + + +class QueueServiceClient(AsyncStorageAccountHostsMixin, QueueServiceClientBase): + """A client to interact with the Queue Service at the account level. + + This client provides operations to retrieve and configure the account properties + as well as list, create and delete queues within the account. + For operations relating to a specific queue, a client for this entity + can be retrieved using the :func:`~get_queue_client` function. + + :ivar str url: + The full queue service endpoint URL, including SAS token if used. This could be + either the primary endpoint, or the secondard endpint depending on the current `location_mode`. + :ivar str primary_endpoint: + The full primary endpoint URL. + :ivar str primary_hostname: + The hostname of the primary endpoint. + :ivar str secondary_endpoint: + The full secondard endpoint URL if configured. If not available + a ValueError will be raised. To explicitly specify a secondary hostname, use the optional + `secondary_hostname` keyword argument on instantiation. + :ivar str secondary_hostname: + The hostname of the secondary endpoint. 
If not available this
+        will be None. To explicitly specify a secondary hostname, use the optional
+        `secondary_hostname` keyword argument on instantiation.
+    :ivar str location_mode:
+        The location mode that the client is currently using. By default
+        this will be "primary". Options include "primary" and "secondary".
+    :param str account_url:
+        The URL to the queue service endpoint. Any other entities included
+        in the URL path (e.g. queue) will be discarded. This URL can be optionally
+        authenticated with a SAS token.
+    :param credential:
+        The credentials with which to authenticate. This is optional if the
+        account URL already has a SAS token. The value can be a SAS token string, an account
+        shared access key, or an instance of a TokenCredentials class from azure.identity.
+
+    Example:
+        .. literalinclude:: ../tests/test_queue_samples_authentication.py
+            :start-after: [START create_queue_service_client]
+            :end-before: [END create_queue_service_client]
+            :language: python
+            :dedent: 8
+            :caption: Creating the QueueServiceClient with an account url and credential.
+    """
+
+    def __init__(
+            self, account_url,  # type: str
+            credential=None,  # type: Optional[Any]
+            loop=None,  # type: Any
+            **kwargs  # type: Any
+        ):
+        # type: (...) -> None
+        kwargs['retry_policy'] = kwargs.get('retry_policy') or ExponentialRetry(**kwargs)
+        super(QueueServiceClient, self).__init__(  # type: ignore
+            account_url,
+            credential=credential,
+            loop=loop,
+            **kwargs)
+        self._client = AzureQueueStorage(url=self.url, pipeline=self._pipeline, loop=loop)  # type: ignore
+        self._loop = loop
+
+    async def get_service_stats(self, timeout=None, **kwargs):  # type: ignore
+        # type: (Optional[int], Optional[Any]) -> Dict[str, Any]
+        """Retrieves statistics related to replication for the Queue service.
+
+        It is only available when read-access geo-redundant replication is enabled for
+        the storage account.
+
+        With geo-redundant replication, Azure Storage maintains your data durably
+        in two locations. In both locations, Azure Storage constantly maintains
+        multiple healthy replicas of your data. The location where you read,
+        create, update, or delete data is the primary storage account location.
+        The primary location exists in the region you choose at the time you
+        create an account via the Azure classic portal, for
+        example, North Central US. The location to which your data is replicated
+        is the secondary location. The secondary location is automatically
+        determined based on the location of the primary; it is in a second data
+        center that resides in the same region as the primary location. Read-only
+        access is available from the secondary location, if read-access geo-redundant
+        replication is enabled for your storage account.
+
+        :param int timeout:
+            The timeout parameter is expressed in seconds.
+        :return: The queue service stats.
+        :rtype: ~azure.storage.queue._generated.models._models.StorageServiceStats
+        """
+        try:
+            return await self._client.service.get_statistics(  # type: ignore
+                timeout=timeout, use_location=LocationMode.SECONDARY, **kwargs)
+        except StorageErrorException as error:
+            process_storage_error(error)
+
+    async def get_service_properties(self, timeout=None, **kwargs):  # type: ignore
+        # type: (Optional[int], Optional[Any]) -> Dict[str, Any]
+        """Gets the properties of a storage account's Queue service, including
+        Azure Storage Analytics.
+
+        :param int timeout:
+            The timeout parameter is expressed in seconds.
+ :rtype: ~azure.storage.queue._generated.models._models.StorageServiceProperties + + Example: + .. literalinclude:: ../tests/test_queue_samples_service.py + :start-after: [START get_queue_service_properties] + :end-before: [END get_queue_service_properties] + :language: python + :dedent: 8 + :caption: Getting queue service properties. + """ + try: + return await self._client.service.get_properties(timeout=timeout, **kwargs) # type: ignore + except StorageErrorException as error: + process_storage_error(error) + + async def set_service_properties( # type: ignore + self, logging=None, # type: Optional[Logging] + hour_metrics=None, # type: Optional[Metrics] + minute_metrics=None, # type: Optional[Metrics] + cors=None, # type: Optional[List[CorsRule]] + timeout=None, # type: Optional[int] + **kwargs + ): + # type: (...) -> None + """Sets the properties of a storage account's Queue service, including + Azure Storage Analytics. + + If an element (e.g. Logging) is left as None, the + existing settings on the service for that functionality are preserved. + + :param logging: + Groups the Azure Analytics Logging settings. + :type logging: ~azure.storage.queue.models.Logging + :param hour_metrics: + The hour metrics settings provide a summary of request + statistics grouped by API in hourly aggregates for queues. + :type hour_metrics: ~azure.storage.queue.models.Metrics + :param minute_metrics: + The minute metrics settings provide request statistics + for each minute for queues. + :type minute_metrics: ~azure.storage.queue.models.Metrics + :param cors: + You can include up to five CorsRule elements in the + list. If an empty list is specified, all CORS rules will be deleted, + and CORS will be disabled for the service. + :type cors: list(:class:`~azure.storage.queue.models.CorsRule`) + :param int timeout: + The timeout parameter is expressed in seconds. + :rtype: None + + Example: + .. literalinclude:: ../tests/test_queue_samples_service.py + :start-after: [START set_queue_service_properties] + :end-before: [END set_queue_service_properties] + :language: python + :dedent: 8 + :caption: Setting queue service properties. + """ + props = StorageServiceProperties( + logging=logging, + hour_metrics=hour_metrics, + minute_metrics=minute_metrics, + cors=cors + ) + try: + return await self._client.service.set_properties(props, timeout=timeout, **kwargs) # type: ignore + except StorageErrorException as error: + process_storage_error(error) + + def list_queues( + self, name_starts_with=None, # type: Optional[str] + include_metadata=False, # type: Optional[bool] + marker=None, # type: Optional[str] + results_per_page=None, # type: Optional[int] + timeout=None, # type: Optional[int] + **kwargs + ): + # type: (...) -> QueuePropertiesPaged + """Returns a generator to list the queues under the specified account. + + The generator will lazily follow the continuation tokens returned by + the service and stop when all queues have been returned. + + :param str name_starts_with: + Filters the results to return only queues whose names + begin with the specified prefix. + :param bool include_metadata: + Specifies that queue metadata be returned in the response. + :param str marker: + An opaque continuation token. This value can be retrieved from the + next_marker field of a previous generator object. If specified, + this generator will begin returning results from this point. + :param int results_per_page: + The maximum number of queue names to retrieve per API + call. 
If the request does not specify a value, the server will return up to 5,000 items.
+        :param int timeout:
+            The server timeout, expressed in seconds. This function may make multiple
+            calls to the service in which case the timeout value specified will be
+            applied to each individual call.
+        :returns: An iterable (auto-paging) of QueueProperties.
+        :rtype: ~azure.storage.queue.aio.models.QueuePropertiesPaged
+
+        Example:
+            .. literalinclude:: ../tests/test_queue_samples_service.py
+                :start-after: [START qsc_list_queues]
+                :end-before: [END qsc_list_queues]
+                :language: python
+                :dedent: 12
+                :caption: List queues in the service.
+        """
+        include = ['metadata'] if include_metadata else None
+        command = functools.partial(
+            self._client.service.list_queues_segment,
+            prefix=name_starts_with,
+            include=include,
+            timeout=timeout,
+            **kwargs)
+        return QueuePropertiesPaged(
+            command, prefix=name_starts_with, results_per_page=results_per_page, marker=marker)
+
+    async def create_queue(  # type: ignore
+            self, name,  # type: str
+            metadata=None,  # type: Optional[Dict[str, str]]
+            timeout=None,  # type: Optional[int]
+            **kwargs
+        ):
+        # type: (...) -> QueueClient
+        """Creates a new queue under the specified account.
+
+        If a queue with the same name already exists, the operation fails.
+        Returns a client with which to interact with the newly created queue.
+
+        :param str name: The name of the queue to create.
+        :param metadata:
+            A dict with name-value pairs to associate with the
+            queue as metadata. Example: {'Category': 'test'}
+        :type metadata: dict(str, str)
+        :param int timeout:
+            The timeout parameter is expressed in seconds.
+        :rtype: ~azure.storage.queue.aio.queue_client_async.QueueClient
+
+        Example:
+            .. literalinclude:: ../tests/test_queue_samples_service.py
+                :start-after: [START qsc_create_queue]
+                :end-before: [END qsc_create_queue]
+                :language: python
+                :dedent: 8
+                :caption: Create a queue in the service.
+        """
+        queue = self.get_queue_client(name)
+        await queue.create_queue(
+            metadata=metadata, timeout=timeout, **kwargs)
+        return queue
+
+    async def delete_queue(  # type: ignore
+            self, queue,  # type: Union[QueueProperties, str]
+            timeout=None,  # type: Optional[int]
+            **kwargs
+        ):
+        # type: (...) -> None
+        """Deletes the specified queue and any messages it contains.
+
+        When a queue is successfully deleted, it is immediately marked for deletion
+        and is no longer accessible to clients. The queue is later removed from
+        the Queue service during garbage collection.
+
+        Note that deleting a queue is likely to take at least 40 seconds to complete.
+        If an operation is attempted against the queue while it is being deleted,
+        an :class:`HttpResponseError` will be thrown.
+
+        :param queue:
+            The queue to delete. This can either be the name of the queue,
+            or an instance of QueueProperties.
+        :type queue: str or ~azure.storage.queue.models.QueueProperties
+        :param int timeout:
+            The timeout parameter is expressed in seconds.
+        :rtype: None
+
+        Example:
+            .. literalinclude:: ../tests/test_queue_samples_service.py
+                :start-after: [START qsc_delete_queue]
+                :end-before: [END qsc_delete_queue]
+                :language: python
+                :dedent: 12
+                :caption: Delete a queue in the service.
+        """
+        queue_client = self.get_queue_client(queue)
+        await queue_client.delete_queue(timeout=timeout, **kwargs)
+
+    def get_queue_client(self, queue, **kwargs):
+        # type: (Union[QueueProperties, str], Optional[Any]) -> QueueClient
+        """Get a client to interact with the specified queue.
+
+        The queue need not already exist.
+
+        :param queue:
+            The queue.
This can either be the name of the queue,
+            or an instance of QueueProperties.
+        :type queue: str or ~azure.storage.queue.models.QueueProperties
+        :returns: A :class:`~azure.storage.queue.aio.queue_client_async.QueueClient` object.
+        :rtype: ~azure.storage.queue.aio.queue_client_async.QueueClient
+
+        Example:
+            .. literalinclude:: ../tests/test_queue_samples_service.py
+                :start-after: [START get_queue_client]
+                :end-before: [END get_queue_client]
+                :language: python
+                :dedent: 8
+                :caption: Get the queue client.
+        """
+        return QueueClient(
+            self.url, queue=queue, credential=self.credential, key_resolver_function=self.key_resolver_function,
+            require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key,
+            _pipeline=self._pipeline, _configuration=self._config, _location_mode=self._location_mode,
+            _hosts=self._hosts, loop=self._loop, **kwargs)
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/models.py
index d9349c4fdb63..043ff06ad8cd 100644
--- a/sdk/storage/azure-storage-queue/azure/storage/queue/models.py
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/models.py
@@ -8,7 +8,9 @@
 from typing import List  # pylint: disable=unused-import
 
 from azure.core.paging import Paged
-from ._shared.response_handlers import return_context_and_deserialized, process_storage_error
+from ._shared.response_handlers import (
+    process_storage_error,
+    return_context_and_deserialized)
 from ._shared.models import DictMixin
 from ._generated.models import StorageErrorException
 from ._generated.models import AccessPolicy as GenAccessPolicy
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/queue_client.py
index 22aa8aad7b17..8815bb61007a 100644
--- a/sdk/storage/azure-storage-queue/azure/storage/queue/queue_client.py
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/queue_client.py
@@ -23,11 +23,8 @@
     process_storage_error,
     return_response_headers,
     return_headers_and_deserialized)
-from ._queue_utils import (
-    TextXMLEncodePolicy,
-    TextXMLDecodePolicy,
-    deserialize_queue_properties,
-    deserialize_queue_creation)
+from ._message_encoding import TextXMLEncodePolicy, TextXMLDecodePolicy
+from ._deserialize import deserialize_queue_properties, deserialize_queue_creation
 from ._generated import AzureQueueStorage
 from ._generated.models import StorageErrorException, SignedIdentifier
 from ._generated.models import QueueMessage as GenQueueMessage
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/queue_service_client.py
index 105074a61b11..c4eef8ff21e1 100644
--- a/sdk/storage/azure-storage-queue/azure/storage/queue/queue_service_client.py
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/queue_service_client.py
@@ -354,7 +354,7 @@ def create_queue(
         **kwargs
     ):
         # type: (...) -> QueueClient
-        """Creates a new queue under the specified account. 
+        """Creates a new queue under the specified account.
 
         If a queue with the same name already exists, the operation fails.
         Returns a client with which to interact with the newly created queue.
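To make the new async surface concrete, here is a minimal usage sketch (not part of the diff) wiring together `create_queue`, `get_queue_client`, message send/receive, and `list_queues` paging as added above. The account URL, key, and queue name are placeholders, and error handling is elided.

```python
# A hypothetical end-to-end sketch of the async client surface in this PR.
import asyncio

from azure.storage.queue.aio import QueueServiceClient


async def main():
    qsc = QueueServiceClient(
        account_url="https://myaccount.queue.core.windows.net",  # placeholder
        credential="<account-key>")  # placeholder

    # create_queue returns a QueueClient scoped to the new queue.
    queue = await qsc.create_queue("pr-demo-queue", metadata={"env": "demo"})

    # Send one message, then receive and delete it.
    await queue.enqueue_message(u"hello")
    async for message in queue.receive_messages():
        print(message.content)
        await queue.delete_message(message)
        break

    # list_queues lazily follows continuation tokens across pages.
    async for props in qsc.list_queues(name_starts_with="pr-demo"):
        print(props.name)

    await qsc.delete_queue("pr-demo-queue")


loop = asyncio.get_event_loop()
loop.run_until_complete(main())
```

The `loop.run_until_complete` pattern mirrors the one used throughout the async tests below, which predate `asyncio.run`.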
diff --git a/sdk/storage/azure-storage-queue/conftest.py b/sdk/storage/azure-storage-queue/conftest.py new file mode 100644 index 000000000000..330109f55cd3 --- /dev/null +++ b/sdk/storage/azure-storage-queue/conftest.py @@ -0,0 +1,16 @@ +# coding: utf-8 + +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +import pytest + + +# Ignore async tests for Python < 3.5 +collect_ignore_glob = [] +if sys.version_info < (3, 5): + collect_ignore_glob.append("tests/*_async.py") diff --git a/sdk/storage/azure-storage-queue/tests/settings_fake.py b/sdk/storage/azure-storage-queue/tests/settings_fake.py deleted file mode 100644 index c4fe5917a862..000000000000 --- a/sdk/storage/azure-storage-queue/tests/settings_fake.py +++ /dev/null @@ -1,55 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -# NOTE: these keys are fake, but valid base-64 data, they were generated using: -# base64.b64encode(os.urandom(64)) - -STORAGE_ACCOUNT_NAME = "storagename" -QUEUE_NAME = "pythonqueue" -STORAGE_ACCOUNT_KEY = "NzhL3hKZbJBuJ2484dPTR+xF30kYaWSSCbs2BzLgVVI1woqeST/1IgqaLm6QAOTxtGvxctSNbIR/1hW8yH+bJg==" -BLOB_STORAGE_ACCOUNT_NAME = "blobstoragename" -BLOB_STORAGE_ACCOUNT_KEY = "NzhL3hKZbJBuJ2484dPTR+xF30kYaWSSCbs2BzLgVVI1woqeST/1IgqaLm6QAOTxtGvxctSNbIR/1hW8yH+bJg==" -REMOTE_STORAGE_ACCOUNT_NAME = "storagename" -REMOTE_STORAGE_ACCOUNT_KEY = "NzhL3hKZbJBuJ2484dPTR+xF30kYaWSSCbs2BzLgVVI1woqeST/1IgqaLm6QAOTxtGvxctSNbIR/1hW8yH+bJg==" -PREMIUM_STORAGE_ACCOUNT_NAME = "premiumstoragename" -PREMIUM_STORAGE_ACCOUNT_KEY = "NzhL3hKZbJBuJ2484dPTR+xF30kYaWSSCbs2BzLgVVI1woqeST/1IgqaLm6QAOTxtGvxctSNbIR/1hW8yH+bJg==" -OAUTH_STORAGE_ACCOUNT_NAME = "oauthstoragename" -OAUTH_STORAGE_ACCOUNT_KEY = "XBB/YoZ41bDFBW1VcgCBNYmA1PDlc3NvQQaCk2rb/JtBoMBlekznQwAzDJHvZO1gJmCh8CUT12Gv3aCkWaDeGA==" - -# Configurations related to Active Directory, which is used to obtain a token credential -ACTIVE_DIRECTORY_APPLICATION_ID = "68390a19-a897-236b-b453-488abf67b4fc" -ACTIVE_DIRECTORY_APPLICATION_SECRET = "3Ujhg7pzkOeE7flc6Z187ugf5/cJnszGPjAiXmcwhaY=" -ACTIVE_DIRECTORY_TENANT_ID = "32f988bf-54f1-15af-36ab-2d7cd364db47" - -# Use instead of STORAGE_ACCOUNT_NAME and STORAGE_ACCOUNT_KEY if custom settings are needed -CONNECTION_STRING = "DefaultEndpointsProtocol=https;AccountName=storagename;AccountKey=NzhL3hKZbJBuJ2484dPTR+xF30kYaWSSCbs2BzLgVVI1woqeST/1IgqaLm6QAOTxtGvxctSNbIR/1hW8yH+bJg==;EndpointSuffix=core.windows.net" -BLOB_CONNECTION_STRING = "DefaultEndpointsProtocol=https;AccountName=blobstoragename;AccountKey=NzhL3hKZbJBuJ2484dPTR+xF30kYaWSSCbs2BzLgVVI1woqeST/1IgqaLm6QAOTxtGvxctSNbIR/1hW8yH+bJg==;EndpointSuffix=core.windows.net" -PREMIUM_CONNECTION_STRING = "DefaultEndpointsProtocol=https;AccountName=premiumstoragename;AccountKey=NzhL3hKZbJBuJ2484dPTR+xF30kYaWSSCbs2BzLgVVI1woqeST/1IgqaLm6QAOTxtGvxctSNbIR/1hW8yH+bJg==;EndpointSuffix=core.windows.net" -# Use 'https' or 'http' protocol for sending requests, 'https' highly recommended -PROTOCOL = "https" - -# Set to true to target the development storage emulator 
-IS_EMULATED = False - -# Set to true if server side file encryption is enabled -IS_SERVER_SIDE_FILE_ENCRYPTION_ENABLED = True - -# Decide which test mode to run against. Possible options: -# - Playback: run against stored recordings -# - Record: run tests against live storage and update recordings -# - RunLiveNoRecord: run tests against live storage without altering recordings -TEST_MODE = 'Playback' - -# Set to true to enable logging for the tests -# logging is not enabled by default because it pollutes the CI logs -ENABLE_LOGGING = False - -# Set up proxy support -USE_PROXY = False -PROXY_HOST = "192.168.15.116" -PROXY_PORT = "8118" -PROXY_USER = "" -PROXY_PASSWORD = "" diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_async.py b/sdk/storage/azure-storage-queue/tests/test_queue_async.py new file mode 100644 index 000000000000..718839c670f0 --- /dev/null +++ b/sdk/storage/azure-storage-queue/tests/test_queue_async.py @@ -0,0 +1,1116 @@ +# coding: utf-8 + +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import unittest +import pytest +import asyncio +from dateutil.tz import tzutc +from datetime import ( + datetime, + timedelta, + date, +) + +from azure.core.exceptions import ( + HttpResponseError, + ResourceNotFoundError, + ResourceExistsError) + +from azure.storage.queue.aio import QueueServiceClient, QueueClient +from azure.storage.queue import ( + QueuePermissions, + AccessPolicy, + ResourceTypes, + AccountPermissions, +) + +from queuetestcase import ( + QueueTestCase, + TestMode, + record, + LogCaptured, +) + +# ------------------------------------------------------------------------------ +TEST_QUEUE_PREFIX = 'pythonqueue' + + +# ------------------------------------------------------------------------------ + + +class StorageQueueTestAsync(QueueTestCase): + def setUp(self): + super(StorageQueueTestAsync, self).setUp() + + queue_url = self._get_queue_url() + credentials = self._get_shared_key_credential() + self.qsc = QueueServiceClient(account_url=queue_url, credential=credentials) + self.test_queues = [] + + def tearDown(self): + if not self.is_playback(): + loop = asyncio.get_event_loop() + for queue in self.test_queues: + try: + loop.run_until_complete(queue.delete_queue()) + except: + pass + return super(StorageQueueTestAsync, self).tearDown() + + # --Helpers----------------------------------------------------------------- + def _get_queue_reference(self, prefix=TEST_QUEUE_PREFIX): + queue_name = self.get_resource_name(prefix) + queue = self.qsc.get_queue_client(queue_name) + self.test_queues.append(queue) + return queue + + async def _create_queue(self, prefix=TEST_QUEUE_PREFIX): + queue = self._get_queue_reference(prefix) + created = await queue.create_queue() + return queue + + # --Test cases for queues ---------------------------------------------- + async def _test_create_queue(self): + # Action + queue_client = self._get_queue_reference() + created = await queue_client.create_queue() + + # Asserts + self.assertTrue(created) + + def test_create_queue(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_create_queue()) + + async def _test_create_queue_fail_on_exist(self): + # Action + queue_client = 
self._get_queue_reference() + created = await queue_client.create_queue() + with self.assertRaises(ResourceExistsError): + await queue_client.create_queue() + + # Asserts + self.assertTrue(created) + + def test_create_queue_fail_on_exist(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_create_queue_fail_on_exist()) + + async def _test_create_queue_fail_on_exist_different_metadata(self): + # Action + queue_client = self._get_queue_reference() + created = await queue_client.create_queue() + with self.assertRaises(ResourceExistsError): + await queue_client.create_queue(metadata={"val": "value"}) + + # Asserts + self.assertTrue(created) + + def test_create_queue_fail_on_exist_different_metadata(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_create_queue_fail_on_exist_different_metadata()) + + async def _test_create_queue_with_options(self): + # Action + queue_client = self._get_queue_reference() + await queue_client.create_queue( + metadata={'val1': 'test', 'val2': 'blah'}) + props = await queue_client.get_queue_properties() + + # Asserts + self.assertEqual(0, props.approximate_message_count) + self.assertEqual(2, len(props.metadata)) + self.assertEqual('test', props.metadata['val1']) + self.assertEqual('blah', props.metadata['val2']) + + def test_create_queue_with_options(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_create_queue_with_options()) + + async def _test_delete_non_existing_queue(self): + # Action + queue_client = self._get_queue_reference() + + # Asserts + with self.assertRaises(ResourceNotFoundError): + await queue_client.delete_queue() + + def test_delete_non_existing_queue(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_delete_non_existing_queue()) + + async def _test_delete_existing_queue_fail_not_exist(self): + # Action + queue_client = self._get_queue_reference() + + created = await queue_client.create_queue() + deleted = await queue_client.delete_queue() + + # Asserts + self.assertIsNone(deleted) + + def test_delete_existing_queue_fail_not_exist(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_delete_existing_queue_fail_not_exist()) + + async def _test_list_queues(self): + # Action + queues = [] + async for q in self.qsc.list_queues(): + queues.append(q) + + # Asserts + self.assertIsNotNone(queues) + self.assertTrue(len(self.test_queues) <= len(queues)) + + def test_list_queues(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_list_queues()) + + async def _test_list_queues_with_options(self): + # Arrange + prefix = 'listqueue' + for i in range(0, 4): + await self._create_queue(prefix + str(i)) + + # Action + generator1 = self.qsc.list_queues(name_starts_with=prefix, results_per_page=3) + await generator1.__anext__() + queues1 = generator1.current_page + + generator2 = self.qsc.list_queues( + name_starts_with=prefix, + marker=generator1.next_marker, + include_metadata=True) + await generator2.__anext__() + queues2 = generator2.current_page + + # Asserts + self.assertIsNotNone(queues1) + self.assertEqual(3, len(queues1)) + 
self.assertIsNotNone(queues1[0]) + self.assertIsNone(queues1[0].metadata) + self.assertNotEqual('', queues1[0].name) + # Asserts + self.assertIsNotNone(queues2) + self.assertTrue(len(self.test_queues) - 3 <= len(queues2)) + self.assertIsNotNone(queues2[0]) + self.assertNotEqual('', queues2[0].name) + + def test_list_queues_with_options_async(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_list_queues_with_options()) + + async def _test_list_queues_with_metadata(self): + # Action + queue = await self._create_queue() + await queue.set_queue_metadata(metadata={'val1': 'test', 'val2': 'blah'}) + + listed_queue = [] + async for q in self.qsc.list_queues( + name_starts_with=queue.queue_name, + results_per_page=1, + include_metadata=True): + listed_queue.append(q) + listed_queue = listed_queue[0] + + # Asserts + self.assertIsNotNone(listed_queue) + self.assertEqual(queue.queue_name, listed_queue.name) + self.assertIsNotNone(listed_queue.metadata) + self.assertEqual(len(listed_queue.metadata), 2) + self.assertEqual(listed_queue.metadata['val1'], 'test') + + def test_list_queues_with_metadata(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_list_queues_with_metadata()) + + async def _test_set_queue_metadata(self): + # Action + metadata = {'hello': 'world', 'number': '43'} + queue = await self._create_queue() + + # Act + await queue.set_queue_metadata(metadata) + metadata_from_response = await queue.get_queue_properties() + md = metadata_from_response.metadata + # Assert + self.assertDictEqual(md, metadata) + + def test_set_queue_metadata(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_set_queue_metadata()) + + async def _test_get_queue_metadata_message_count(self): + # Action + queue_client = await self._create_queue() + await queue_client.enqueue_message(u'message1') + props = await queue_client.get_queue_properties() + + # Asserts + self.assertTrue(props.approximate_message_count >= 1) + self.assertEqual(0, len(props.metadata)) + + def test_get_queue_metadata_message_count(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_get_queue_metadata_message_count()) + + async def _test_queue_exists(self): + # Arrange + queue = await self._create_queue() + + # Act + exists = await queue.get_queue_properties() + + # Assert + self.assertTrue(exists) + + def test_queue_exists(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_queue_exists()) + + async def _test_queue_not_exists(self): + # Arrange + queue = self.qsc.get_queue_client(self.get_resource_name('missing')) + # Act + with self.assertRaises(ResourceNotFoundError): + await queue.get_queue_properties() + + # Assert + + def test_queue_not_exists(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_queue_not_exists()) + + async def _test_put_message(self): + # Action. No exception means pass. No asserts needed. 
+ queue_client = await self._create_queue() + await queue_client.enqueue_message(u'message1') + await queue_client.enqueue_message(u'message2') + await queue_client.enqueue_message(u'message3') + message = await queue_client.enqueue_message(u'message4') + + # Asserts + self.assertIsNotNone(message) + self.assertNotEqual('', message.id) + self.assertIsInstance(message.insertion_time, datetime) + self.assertIsInstance(message.expiration_time, datetime) + self.assertNotEqual('', message.pop_receipt) + self.assertEqual(u'message4', message.content) + + def test_put_message(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_put_message()) + + async def _test_put_message_large_time_to_live(self): + # Arrange + queue_client = await self._create_queue() + # There should be no upper bound on a queue message's time to live + await queue_client.enqueue_message(u'message1', time_to_live=1024*1024*1024) + + # Act + messages = await queue_client.peek_messages() + + # Assert + self.assertGreaterEqual( + messages[0].expiration_time, + messages[0].insertion_time + timedelta(seconds=1024 * 1024 * 1024 - 3600)) + + def test_put_message_large_time_to_live(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_put_message_large_time_to_live()) + + async def _test_put_message_infinite_time_to_live(self): + # Arrange + queue_client = await self._create_queue() + await queue_client.enqueue_message(u'message1', time_to_live=-1) + + # Act + messages = await queue_client.peek_messages() + + # Assert + self.assertEqual(messages[0].expiration_time.year, date.max.year) + + def test_put_message_infinite_time_to_live(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_put_message_infinite_time_to_live()) + + async def _test_get_messages(self): + # Action + queue_client = await self._create_queue() + await queue_client.enqueue_message(u'message1') + await queue_client.enqueue_message(u'message2') + await queue_client.enqueue_message(u'message3') + await queue_client.enqueue_message(u'message4') + messages = [] + async for m in queue_client.receive_messages(): + messages.append(m) + if len(messages): + break + message = messages[0] + # Asserts + self.assertIsNotNone(message) + self.assertIsNotNone(message) + self.assertNotEqual('', message.id) + self.assertEqual(u'message1', message.content) + self.assertNotEqual('', message.pop_receipt) + self.assertEqual(1, message.dequeue_count) + + self.assertIsInstance(message.insertion_time, datetime) + self.assertIsInstance(message.expiration_time, datetime) + self.assertIsInstance(message.time_next_visible, datetime) + + def test_get_messages(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_get_messages()) + + async def _test_get_messages_with_options(self): + # Action + queue_client = await self._create_queue() + await queue_client.enqueue_message(u'message1') + await queue_client.enqueue_message(u'message2') + await queue_client.enqueue_message(u'message3') + await queue_client.enqueue_message(u'message4') + result = queue_client.receive_messages(messages_per_page=4, visibility_timeout=20) + await result.__anext__() + + # Asserts + self.assertIsNotNone(result) + self.assertEqual(4, len(result.current_page)) + + for message in 
result.current_page: + self.assertIsNotNone(message) + self.assertNotEqual('', message.id) + self.assertNotEqual('', message.content) + self.assertNotEqual('', message.pop_receipt) + self.assertEqual(1, message.dequeue_count) + self.assertNotEqual('', message.insertion_time) + self.assertNotEqual('', message.expiration_time) + self.assertNotEqual('', message.time_next_visible) + + def test_get_messages_with_options(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_get_messages_with_options()) + + async def _test_peek_messages(self): + # Action + queue_client = await self._create_queue() + await queue_client.enqueue_message(u'message1') + await queue_client.enqueue_message(u'message2') + await queue_client.enqueue_message(u'message3') + await queue_client.enqueue_message(u'message4') + result = await queue_client.peek_messages() + + # Asserts + self.assertIsNotNone(result) + self.assertEqual(1, len(result)) + message = result[0] + self.assertIsNotNone(message) + self.assertNotEqual('', message.id) + self.assertNotEqual('', message.content) + self.assertIsNone(message.pop_receipt) + self.assertEqual(0, message.dequeue_count) + self.assertNotEqual('', message.insertion_time) + self.assertNotEqual('', message.expiration_time) + self.assertIsNone(message.time_next_visible) + + def test_peek_messages(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_peek_messages()) + + async def _test_peek_messages_with_options(self): + # Action + queue_client = await self._create_queue() + await queue_client.enqueue_message(u'message1') + await queue_client.enqueue_message(u'message2') + await queue_client.enqueue_message(u'message3') + await queue_client.enqueue_message(u'message4') + result = await queue_client.peek_messages(max_messages=4) + + # Asserts + self.assertIsNotNone(result) + self.assertEqual(4, len(result)) + for message in result: + self.assertIsNotNone(message) + self.assertNotEqual('', message.id) + self.assertNotEqual('', message.content) + self.assertIsNone(message.pop_receipt) + self.assertEqual(0, message.dequeue_count) + self.assertNotEqual('', message.insertion_time) + self.assertNotEqual('', message.expiration_time) + self.assertIsNone(message.time_next_visible) + + def test_peek_messages_with_options(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_peek_messages_with_options()) + + async def _test_clear_messages(self): + # Action + queue_client = await self._create_queue() + await queue_client.enqueue_message(u'message1') + await queue_client.enqueue_message(u'message2') + await queue_client.enqueue_message(u'message3') + await queue_client.enqueue_message(u'message4') + await queue_client.clear_messages() + result = await queue_client.peek_messages() + + # Asserts + self.assertIsNotNone(result) + self.assertEqual(0, len(result)) + + def test_clear_messages(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_clear_messages()) + + async def _test_delete_message(self): + # Action + queue_client = await self._create_queue() + await queue_client.enqueue_message(u'message1') + await queue_client.enqueue_message(u'message2') + await queue_client.enqueue_message(u'message3') + await queue_client.enqueue_message(u'message4') + messages = [] + async 
for m in queue_client.receive_messages(): + messages.append(m) + await queue_client.delete_message(m) + async for m in queue_client.receive_messages(): + messages.append(m) + # Asserts + self.assertIsNotNone(messages) + self.assertEqual(3, len(messages)-1) + + def test_delete_message(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_delete_message()) + + async def _test_update_message(self): + # Action + queue_client = await self._create_queue() + await queue_client.enqueue_message(u'message1') + messages = [] + async for m in queue_client.receive_messages(): + messages.append(m) + list_result1 = messages[0] + message = await queue_client.update_message( + list_result1.id, + pop_receipt=list_result1.pop_receipt, + visibility_timeout=0) + messages = [] + async for m in queue_client.receive_messages(): + messages.append(m) + list_result2 = messages[0] + + # Asserts + # Update response + self.assertIsNotNone(message) + self.assertIsNotNone(message.pop_receipt) + self.assertIsNotNone(message.time_next_visible) + self.assertIsInstance(message.time_next_visible, datetime) + + # Get response + self.assertIsNotNone(list_result2) + message = list_result2 + self.assertIsNotNone(message) + self.assertEqual(list_result1.id, message.id) + self.assertEqual(u'message1', message.content) + self.assertEqual(2, message.dequeue_count) + self.assertIsNotNone(message.pop_receipt) + self.assertIsNotNone(message.insertion_time) + self.assertIsNotNone(message.expiration_time) + self.assertIsNotNone(message.time_next_visible) + + def test_update_message(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_update_message()) + + async def _test_update_message_content(self): + # Action + queue_client = await self._create_queue() + await queue_client.enqueue_message(u'message1') + + messages = [] + async for m in queue_client.receive_messages(): + messages.append(m) + list_result1 = messages[0] + message = await queue_client.update_message( + list_result1.id, + pop_receipt=list_result1.pop_receipt, + visibility_timeout=0, + content=u'new text') + messages = [] + async for m in queue_client.receive_messages(): + messages.append(m) + list_result2 = messages[0] + + # Asserts + # Update response + self.assertIsNotNone(message) + self.assertIsNotNone(message.pop_receipt) + self.assertIsNotNone(message.time_next_visible) + self.assertIsInstance(message.time_next_visible, datetime) + + # Get response + self.assertIsNotNone(list_result2) + message = list_result2 + self.assertIsNotNone(message) + self.assertEqual(list_result1.id, message.id) + self.assertEqual(u'new text', message.content) + self.assertEqual(2, message.dequeue_count) + self.assertIsNotNone(message.pop_receipt) + self.assertIsNotNone(message.insertion_time) + self.assertIsNotNone(message.expiration_time) + self.assertIsNotNone(message.time_next_visible) + + def test_update_message_content(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_update_message_content()) + + async def _test_account_sas(self): + # SAS URL is calculated from storage key, so this test runs live only + if TestMode.need_recording_file(self.test_mode): + return + + # Arrange + queue_client = await self._create_queue() + await queue_client.enqueue_message(u'message1') + token = self.qsc.generate_shared_access_signature( + 
ResourceTypes.OBJECT, + AccountPermissions.READ, + datetime.utcnow() + timedelta(hours=1), + datetime.utcnow() - timedelta(minutes=5) + ) + + # Act + service = QueueServiceClient( + account_url=self.qsc.url, + credential=token, + ) + new_queue_client = service.get_queue_client(queue_client.queue_name) + result = await new_queue_client.peek_messages() + + # Assert + self.assertIsNotNone(result) + self.assertEqual(1, len(result)) + message = result[0] + self.assertIsNotNone(message) + self.assertNotEqual('', message.id) + self.assertEqual(u'message1', message.content) + + def test_account_sas(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_account_sas()) + + async def _test_token_credential(self): + pytest.skip("") + token_credential = self.generate_oauth_token() + + # Action 1: make sure token works + service = QueueServiceClient(self._get_oauth_queue_url(), credential=token_credential) + queues = await service.get_service_properties() + self.assertIsNotNone(queues) + + # Action 2: change token value to make request fail + fake_credential = self.generate_fake_token() + service = QueueServiceClient(self._get_oauth_queue_url(), credential=fake_credential) + with self.assertRaises(ClientAuthenticationError): + queue_li = await service.list_queues() + list(queue_li) + + # Action 3: update token to make it working again + service = QueueServiceClient(self._get_oauth_queue_url(), credential=token_credential) + queue_li = await service.list_queues() + queues = list(queue_li) + self.assertIsNotNone(queues) + + def test_token_credential(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_token_credential()) + + async def _test_sas_read(self): + # SAS URL is calculated from storage key, so this test runs live only + if TestMode.need_recording_file(self.test_mode): + return + + # Arrange + queue_client = await self._create_queue() + await queue_client.enqueue_message(u'message1') + token = queue_client.generate_shared_access_signature( + QueuePermissions.READ, + datetime.utcnow() + timedelta(hours=1), + datetime.utcnow() - timedelta(minutes=5) + ) + + # Act + service = QueueClient( + queue_url=queue_client.url, + credential=token, + ) + result = await service.peek_messages() + + # Assert + self.assertIsNotNone(result) + self.assertEqual(1, len(result)) + message = result[0] + self.assertIsNotNone(message) + self.assertNotEqual('', message.id) + self.assertEqual(u'message1', message.content) + + def test_sas_read(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_sas_read()) + + async def _test_sas_add(self): + # SAS URL is calculated from storage key, so this test runs live only + if TestMode.need_recording_file(self.test_mode): + return + + # Arrange + queue_client = await self._create_queue() + token = queue_client.generate_shared_access_signature( + QueuePermissions.ADD, + datetime.utcnow() + timedelta(hours=1), + ) + + # Act + service = QueueClient( + queue_url=queue_client.url, + credential=token, + ) + result = await service.enqueue_message(u'addedmessage') + + # Assert + messages = [] + async for m in queue_client.receive_messages(): + messages.append(m) + result = messages[0] + self.assertEqual(u'addedmessage', result.content) + + def test_sas_add(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = 
asyncio.get_event_loop() + loop.run_until_complete(self._test_sas_add()) + + async def _test_sas_update(self): + # SAS URL is calculated from storage key, so this test runs live only + if TestMode.need_recording_file(self.test_mode): + return + + # Arrange + queue_client = await self._create_queue() + await queue_client.enqueue_message(u'message1') + token = queue_client.generate_shared_access_signature( + QueuePermissions.UPDATE, + datetime.utcnow() + timedelta(hours=1), + ) + messages = [] + async for m in queue_client.receive_messages(): + messages.append(m) + result = messages[0] + + # Act + service = QueueClient( + queue_url=queue_client.url, + credential=token, + ) + await service.update_message( + result.id, + pop_receipt=result.pop_receipt, + visibility_timeout=0, + content=u'updatedmessage1', + ) + + # Assert + messages = [] + async for m in queue_client.receive_messages(): + messages.append(m) + result = messages[0] + self.assertEqual(u'updatedmessage1', result.content) + + def test_sas_update(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_sas_update()) + + async def _test_sas_process(self): + # SAS URL is calculated from storage key, so this test runs live only + if TestMode.need_recording_file(self.test_mode): + return + + # Arrange + queue_client = await self._create_queue() + await queue_client.enqueue_message(u'message1') + token = queue_client.generate_shared_access_signature( + QueuePermissions.PROCESS, + datetime.utcnow() + timedelta(hours=1), + ) + + # Act + service = QueueClient( + queue_url=queue_client.url, + credential=token, + ) + messages = [] + async for m in queue_client.receive_messages(): + messages.append(m) + message = messages[0] + + # Assert + self.assertIsNotNone(message) + self.assertNotEqual('', message.id) + self.assertEqual(u'message1', message.content) + + def test_sas_process(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_sas_process()) + + async def _test_sas_signed_identifier(self): + # SAS URL is calculated from storage key, so this test runs live only + if TestMode.need_recording_file(self.test_mode): + return + + # Arrange + access_policy = AccessPolicy() + access_policy.start = datetime.utcnow() - timedelta(hours=1) + access_policy.expiry = datetime.utcnow() + timedelta(hours=1) + access_policy.permission = QueuePermissions.READ + + identifiers = {'testid': access_policy} + + queue_client = await self._create_queue() + resp = await queue_client.set_queue_access_policy(identifiers) + + await queue_client.enqueue_message(u'message1') + + token = queue_client.generate_shared_access_signature( + policy_id='testid' + ) + + # Act + service = QueueClient( + queue_url=queue_client.url, + credential=token, + ) + result = await service.peek_messages() + + # Assert + self.assertIsNotNone(result) + self.assertEqual(1, len(result)) + message = result[0] + self.assertIsNotNone(message) + self.assertNotEqual('', message.id) + self.assertEqual(u'message1', message.content) + + def test_sas_signed_identifier(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_sas_signed_identifier()) + + async def _test_get_queue_acl(self): + # Arrange + queue_client = await self._create_queue() + + # Act + acl = await queue_client.get_queue_access_policy() + + # Assert + self.assertIsNotNone(acl) + 
self.assertEqual(len(acl), 0) + + def test_get_queue_acl(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_get_queue_acl()) + + async def _test_get_queue_acl_iter(self): + # Arrange + queue_client = await self._create_queue() + + # Act + acl = await queue_client.get_queue_access_policy() + for signed_identifier in acl: + pass + + # Assert + self.assertIsNotNone(acl) + self.assertEqual(len(acl), 0) + + def test_get_queue_acl_iter(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_get_queue_acl_iter()) + + async def _test_get_queue_acl_with_non_existing_queue(self): + # Arrange + queue_client = self._get_queue_reference() + + # Act + with self.assertRaises(ResourceNotFoundError): + await queue_client.get_queue_access_policy() + + # Assert + + def test_get_queue_acl_with_non_existing_queue(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_get_queue_acl_with_non_existing_queue()) + + async def _test_set_queue_acl(self): + # Arrange + queue_client = await self._create_queue() + + # Act + resp = await queue_client.set_queue_access_policy() + + # Assert + self.assertIsNone(resp) + acl = await queue_client.get_queue_access_policy() + self.assertIsNotNone(acl) + + def test_set_queue_acl(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_set_queue_acl()) + + async def _test_set_queue_acl_with_empty_signed_identifiers(self): + # Arrange + queue_client = await self._create_queue() + + # Act + await queue_client.set_queue_access_policy(signed_identifiers={}) + + # Assert + acl = await queue_client.get_queue_access_policy() + self.assertIsNotNone(acl) + self.assertEqual(len(acl), 0) + + def test_set_queue_acl_with_empty_signed_identifiers(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_set_queue_acl_with_empty_signed_identifiers()) + + async def _test_set_queue_acl_with_empty_signed_identifier(self): + # Arrange + queue_client = await self._create_queue() + + # Act + await queue_client.set_queue_access_policy(signed_identifiers={'empty': AccessPolicy()}) + + # Assert + acl = await queue_client.get_queue_access_policy() + self.assertIsNotNone(acl) + self.assertEqual(len(acl), 1) + self.assertIsNotNone(acl['empty']) + self.assertIsNone(acl['empty'].permission) + self.assertIsNone(acl['empty'].expiry) + self.assertIsNone(acl['empty'].start) + + def test_set_queue_acl_with_empty_signed_identifier(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_set_queue_acl_with_empty_signed_identifier()) + + async def _test_set_queue_acl_with_signed_identifiers(self): + # Arrange + queue_client = await self._create_queue() + + # Act + access_policy = AccessPolicy(permission=QueuePermissions.READ, + expiry=datetime.utcnow() + timedelta(hours=1), + start=datetime.utcnow() - timedelta(minutes=5)) + identifiers = {'testid': access_policy} + + resp = await queue_client.set_queue_access_policy(signed_identifiers=identifiers) + + # Assert + self.assertIsNone(resp) + acl = await queue_client.get_queue_access_policy() + self.assertIsNotNone(acl) + self.assertEqual(len(acl), 1) + self.assertTrue('testid' in acl) 
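As an aside on what the ACL tests around this point exercise, the following is a hedged sketch (not part of the test file) of the stored-access-policy flow from `_test_sas_signed_identifier` above: a named policy is saved on the queue and a SAS token then refers to it by `policy_id`. Here `queue_client` is assumed to be an already-created aio `QueueClient` for an existing queue, and the identifier name is arbitrary.

```python
from datetime import datetime, timedelta

from azure.storage.queue import AccessPolicy, QueuePermissions


async def grant_read_via_stored_policy(queue_client):
    # Save a named policy on the queue. The service caps the number of
    # stored access policies per queue, which is why the too-many-ids
    # test below expects a ValueError for 16 identifiers.
    policy = AccessPolicy(
        permission=QueuePermissions.READ,
        start=datetime.utcnow() - timedelta(minutes=5),
        expiry=datetime.utcnow() + timedelta(hours=1))
    await queue_client.set_queue_access_policy({'read-only': policy})

    # The SAS references the stored policy instead of embedding its own
    # permissions and expiry, so revoking the policy revokes the token.
    return queue_client.generate_shared_access_signature(policy_id='read-only')
```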
+ + def test_set_queue_acl_with_signed_identifiers(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_set_queue_acl_with_signed_identifiers()) + + async def _test_set_queue_acl_too_many_ids(self): + # Arrange + queue_client = await self._create_queue() + + # Act + identifiers = dict() + for i in range(0, 16): + identifiers['id{}'.format(i)] = AccessPolicy() + + # Assert + with self.assertRaises(ValueError): + await queue_client.set_queue_access_policy(identifiers) + + def test_set_queue_acl_too_many_ids(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_set_queue_acl_too_many_ids()) + + async def _test_set_queue_acl_with_non_existing_queue(self): + # Arrange + queue_client = self._get_queue_reference() + + # Act + with self.assertRaises(ResourceNotFoundError): + await queue_client.set_queue_access_policy() + + # Assert + + def test_set_queue_acl_with_non_existing_queue(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_set_queue_acl_with_non_existing_queue()) + + async def _test_unicode_create_queue_unicode_name(self): + # Action + queue_name = u'啊齄丂狛狜' + + with self.assertRaises(HttpResponseError): + # not supported - queue name must be alphanumeric, lowercase + client = self.qsc.get_queue_client(queue_name) + await client.create_queue() + + # Asserts + + def test_unicode_create_queue_unicode_name(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_unicode_create_queue_unicode_name()) + + async def _test_unicode_get_messages_unicode_data(self): + # Action + queue_client = await self._create_queue() + await queue_client.enqueue_message(u'message1㚈') + message = None + async for m in queue_client.receive_messages(): + message = m + # Asserts + self.assertIsNotNone(message) + self.assertNotEqual('', message.id) + self.assertEqual(u'message1㚈', message.content) + self.assertNotEqual('', message.pop_receipt) + self.assertEqual(1, message.dequeue_count) + self.assertIsInstance(message.insertion_time, datetime) + self.assertIsInstance(message.expiration_time, datetime) + self.assertIsInstance(message.time_next_visible, datetime) + + def test_unicode_get_messages_unicode_data(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_unicode_get_messages_unicode_data()) + + async def _test_unicode_update_message_unicode_data(self): + # Action + queue_client = await self._create_queue() + await queue_client.enqueue_message(u'message1') + messages = [] + async for m in queue_client.receive_messages(): + messages.append(m) + + list_result1 = messages[0] + list_result1.content = u'啊齄丂狛狜' + await queue_client.update_message(list_result1, visibility_timeout=0) + messages = [] + async for m in queue_client.receive_messages(): + messages.append(m) + # Asserts + message = messages[0] + self.assertIsNotNone(message) + self.assertEqual(list_result1.id, message.id) + self.assertEqual(u'啊齄丂狛狜', message.content) + self.assertNotEqual('', message.pop_receipt) + self.assertEqual(2, message.dequeue_count) + self.assertIsInstance(message.insertion_time, datetime) + self.assertIsInstance(message.expiration_time, datetime) + self.assertIsInstance(message.time_next_visible, datetime) + + def 
test_unicode_update_message_unicode_data(self): + if TestMode.need_recording_file(self.test_mode): + return + loop = asyncio.get_event_loop() + loop.run_until_complete(self._test_unicode_update_message_unicode_data()) + + +# ------------------------------------------------------------------------------ +if __name__ == '__main__': + unittest.main() diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_client_async.py b/sdk/storage/azure-storage-queue/tests/test_queue_client_async.py new file mode 100644 index 000000000000..359b43929d68 --- /dev/null +++ b/sdk/storage/azure-storage-queue/tests/test_queue_client_async.py @@ -0,0 +1,430 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import unittest +import pytest +import platform +import asyncio + +from azure.storage.queue.aio import ( + QueueServiceClient, + QueueClient +) +from queuetestcase import ( + QueueTestCase, + record, + TestMode +) + +# ------------------------------------------------------------------------------ +SERVICES = { + QueueServiceClient: 'queue', + QueueClient: 'queue', +} + +_CONNECTION_ENDPOINTS = {'queue': 'QueueEndpoint'} + +_CONNECTION_ENDPOINTS_SECONDARY = {'queue': 'QueueSecondaryEndpoint'} + +class StorageQueueClientTestAsync(QueueTestCase): + def setUp(self): + super(StorageQueueClientTestAsync, self).setUp() + self.account_name = self.settings.STORAGE_ACCOUNT_NAME + self.account_key = self.settings.STORAGE_ACCOUNT_KEY + self.sas_token = '?sv=2015-04-05&st=2015-04-29T22%3A18%3A26Z&se=2015-04-30T02%3A23%3A26Z&sr=b&sp=rw&sip=168.1.5.60-168.1.5.70&spr=https&sig=Z%2FRHIX5Xcg0Mq2rqI3OlWTjEg2tYkboXr1P9ZUXDtkk%3D' + self.token_credential = self.generate_oauth_token() + self.connection_string = self.settings.CONNECTION_STRING + + # --Helpers----------------------------------------------------------------- + def validate_standard_account_endpoints(self, service, url_type): + self.assertIsNotNone(service) + self.assertEqual(service.credential.account_name, self.account_name) + self.assertEqual(service.credential.account_key, self.account_key) + self.assertTrue('{}.{}.core.windows.net'.format(self.account_name, url_type) in service.url) + self.assertTrue('{}-secondary.{}.core.windows.net'.format(self.account_name, url_type) in service.secondary_endpoint) + + # --Direct Parameters Test Cases -------------------------------------------- + def test_create_service_with_key(self): + # Arrange + + for client, url in SERVICES.items(): + # Act + service = client( + self._get_queue_url(), credential=self.account_key, queue='foo') + + # Assert + self.validate_standard_account_endpoints(service, url) + self.assertEqual(service.scheme, 'https') + + def test_create_service_with_connection_string(self): + + for service_type in SERVICES.items(): + # Act + service = service_type[0].from_connection_string( + self.connection_string, queue="test") + + # Assert + self.validate_standard_account_endpoints(service, service_type[1]) + self.assertEqual(service.scheme, 'https') + + def test_create_service_with_sas(self): + # Arrange + + for service_type in SERVICES: + # Act + service = service_type( + self._get_queue_url(), credential=self.sas_token, queue='foo') + + # Assert + self.assertIsNotNone(service) + self.assertTrue(service.url.startswith('https://' + 
self.account_name + '.queue.core.windows.net'))
+            self.assertTrue(service.url.endswith(self.sas_token))
+            self.assertIsNone(service.credential)
+
+    def test_create_service_with_token(self):
+        for service_type in SERVICES:
+            # Act
+            service = service_type(
+                self._get_queue_url(), credential=self.token_credential, queue='foo')
+
+            # Assert
+            self.assertIsNotNone(service)
+            self.assertTrue(service.url.startswith('https://' + self.account_name + '.queue.core.windows.net'))
+            self.assertEqual(service.credential, self.token_credential)
+            self.assertFalse(hasattr(service.credential, 'account_key'))
+            self.assertTrue(hasattr(service.credential, 'get_token'))
+
+    def test_create_service_with_token_and_http(self):
+        for service_type in SERVICES:
+            # Act
+            with self.assertRaises(ValueError):
+                url = self._get_queue_url().replace('https', 'http')
+                service_type(url, credential=self.token_credential, queue='foo')
+
+    def test_create_service_china(self):
+        # Arrange
+
+        for service_type in SERVICES.items():
+            # Act
+            url = self._get_queue_url().replace('core.windows.net', 'core.chinacloudapi.cn')
+            service = service_type[0](
+                url, credential=self.account_key, queue='foo')
+
+            # Assert
+            self.assertIsNotNone(service)
+            self.assertEqual(service.credential.account_name, self.account_name)
+            self.assertEqual(service.credential.account_key, self.account_key)
+            self.assertTrue(service.primary_endpoint.startswith(
+                'https://{}.{}.core.chinacloudapi.cn'.format(self.account_name, service_type[1])))
+            self.assertTrue(service.secondary_endpoint.startswith(
+                'https://{}-secondary.{}.core.chinacloudapi.cn'.format(self.account_name, service_type[1])))
+
+    def test_create_service_protocol(self):
+        # Arrange
+
+        for service_type in SERVICES.items():
+            # Act
+            url = self._get_queue_url().replace('https', 'http')
+            service = service_type[0](
+                url, credential=self.account_key, queue='foo')
+
+            # Assert
+            self.validate_standard_account_endpoints(service, service_type[1])
+            self.assertEqual(service.scheme, 'http')
+
+    def test_create_service_empty_key(self):
+        # Arrange
+        QUEUE_SERVICES = [QueueServiceClient, QueueClient]
+
+        for service_type in QUEUE_SERVICES:
+            # Act
+            with self.assertRaises(ValueError) as e:
+                test_service = service_type('testaccount', credential='', queue='foo')
+
+            self.assertEqual(
+                str(e.exception), "You need to provide either a SAS token or an account key to authenticate.")
+
+    def test_create_service_missing_arguments(self):
+        # Arrange
+
+        for service_type in SERVICES:
+            # Act
+            with self.assertRaises(ValueError):
+                service = service_type(None)
+            # Assert
+
+    def test_create_service_with_socket_timeout(self):
+        # Arrange
+
+        for service_type in SERVICES.items():
+            # Act
+            default_service = service_type[0](
+                self._get_queue_url(), credential=self.account_key, queue='foo')
+            service = service_type[0](
+                self._get_queue_url(), credential=self.account_key,
+                queue='foo', connection_timeout=22)
+
+            # Assert
+            self.validate_standard_account_endpoints(service, service_type[1])
+            self.assertEqual(service._config.connection.timeout, 22)
+            self.assertTrue(default_service._config.connection.timeout in [20, (20, 2000)])
+
+    # --Connection String Test Cases --------------------------------------------
+
+    def test_create_service_with_connection_string_key(self):
+        # Arrange
+        conn_string = 'AccountName={};AccountKey={};'.format(self.account_name, self.account_key)
+
+        for service_type in SERVICES.items():
+            # Act
+            service = service_type[0].from_connection_string(conn_string, queue='foo')
+
+            # Assert
+            self.validate_standard_account_endpoints(service, service_type[1])
+            self.assertEqual(service.scheme, 'https')
+
+    def test_create_service_with_connection_string_sas(self):
+        # Arrange
+        conn_string = 'AccountName={};SharedAccessSignature={};'.format(self.account_name, self.sas_token)
+
+        for service_type in SERVICES:
+            # Act
+            service = service_type.from_connection_string(conn_string, queue='foo')
+
+            # Assert
+            self.assertIsNotNone(service)
+            self.assertTrue(service.url.startswith('https://' + self.account_name + '.queue.core.windows.net'))
+            self.assertTrue(service.url.endswith(self.sas_token))
+            self.assertIsNone(service.credential)
+
+    def test_create_service_with_connection_string_endpoint_protocol(self):
+        # Arrange
+        conn_string = 'AccountName={};AccountKey={};DefaultEndpointsProtocol=http;EndpointSuffix=core.chinacloudapi.cn;'.format(
+            self.account_name, self.account_key)
+
+        for service_type in SERVICES.items():
+            # Act
+            service = service_type[0].from_connection_string(conn_string, queue="foo")
+
+            # Assert
+            self.assertIsNotNone(service)
+            self.assertEqual(service.credential.account_name, self.account_name)
+            self.assertEqual(service.credential.account_key, self.account_key)
+            self.assertTrue(
+                service.primary_endpoint.startswith(
+                    'http://{}.{}.core.chinacloudapi.cn/'.format(self.account_name, service_type[1])))
+            self.assertTrue(
+                service.secondary_endpoint.startswith(
+                    'http://{}-secondary.{}.core.chinacloudapi.cn'.format(self.account_name, service_type[1])))
+            self.assertEqual(service.scheme, 'http')
+
+    def test_create_service_with_connection_string_emulated(self):
+        # Arrange
+        for service_type in SERVICES.items():
+            conn_string = 'UseDevelopmentStorage=true;'
+
+            # Act
+            with self.assertRaises(ValueError):
+                service = service_type[0].from_connection_string(conn_string, queue="foo")
+
+    def test_create_service_with_connection_string_custom_domain(self):
+        # Arrange
+        for service_type in SERVICES.items():
+            conn_string = 'AccountName={};AccountKey={};QueueEndpoint=www.mydomain.com;'.format(
+                self.account_name, self.account_key)
+
+            # Act
+            service = service_type[0].from_connection_string(conn_string, queue="foo")
+
+            # Assert
+            self.assertIsNotNone(service)
+            self.assertEqual(service.credential.account_name, self.account_name)
+            self.assertEqual(service.credential.account_key, self.account_key)
+            self.assertTrue(service.primary_endpoint.startswith('https://www.mydomain.com/'))
+            self.assertTrue(service.secondary_endpoint.startswith('https://' + self.account_name + '-secondary.queue.core.windows.net'))
+
+    def test_create_service_with_connection_string_custom_domain_trailing_slash(self):
+        # Arrange
+        for service_type in SERVICES.items():
+            conn_string = 'AccountName={};AccountKey={};QueueEndpoint=www.mydomain.com/;'.format(
+                self.account_name, self.account_key)
+
+            # Act
+            service = service_type[0].from_connection_string(conn_string, queue="foo")
+
+            # Assert
+            self.assertIsNotNone(service)
+            self.assertEqual(service.credential.account_name, self.account_name)
+            self.assertEqual(service.credential.account_key, self.account_key)
+            self.assertTrue(service.primary_endpoint.startswith('https://www.mydomain.com/'))
+            self.assertTrue(service.secondary_endpoint.startswith('https://' + self.account_name + '-secondary.queue.core.windows.net'))
+
+    def test_create_service_with_connection_string_custom_domain_secondary_override(self):
+        # Arrange
+        for service_type in SERVICES.items():
+            conn_string = 'AccountName={};AccountKey={};QueueEndpoint=www.mydomain.com/;'.format(
+                self.account_name, self.account_key)
+
+            # Act
+            service = service_type[0].from_connection_string(
+                conn_string, secondary_hostname="www-sec.mydomain.com", queue="foo")
+
+            # Assert
+            self.assertIsNotNone(service)
+            self.assertEqual(service.credential.account_name, self.account_name)
+            self.assertEqual(service.credential.account_key, self.account_key)
+            self.assertTrue(service.primary_endpoint.startswith('https://www.mydomain.com/'))
+            self.assertTrue(service.secondary_endpoint.startswith('https://www-sec.mydomain.com/'))
+
+    def test_create_service_with_connection_string_fails_if_secondary_without_primary(self):
+        for service_type in SERVICES.items():
+            # Arrange
+            conn_string = 'AccountName={};AccountKey={};{}=www.mydomain.com;'.format(
+                self.account_name, self.account_key,
+                _CONNECTION_ENDPOINTS_SECONDARY.get(service_type[1]))
+
+            # Act
+
+            # Fails if primary excluded
+            with self.assertRaises(ValueError):
+                service = service_type[0].from_connection_string(conn_string, queue="foo")
+
+    def test_create_service_with_connection_string_succeeds_if_secondary_with_primary(self):
+        for service_type in SERVICES.items():
+            # Arrange
+            conn_string = 'AccountName={};AccountKey={};{}=www.mydomain.com;{}=www-sec.mydomain.com;'.format(
+                self.account_name,
+                self.account_key,
+                _CONNECTION_ENDPOINTS.get(service_type[1]),
+                _CONNECTION_ENDPOINTS_SECONDARY.get(service_type[1]))
+
+            # Act
+            service = service_type[0].from_connection_string(conn_string, queue="foo")
+
+            # Assert
+            self.assertIsNotNone(service)
+            self.assertEqual(service.credential.account_name, self.account_name)
+            self.assertEqual(service.credential.account_key, self.account_key)
+            self.assertTrue(service.primary_endpoint.startswith('https://www.mydomain.com/'))
+            self.assertTrue(service.secondary_endpoint.startswith('https://www-sec.mydomain.com/'))
+
+    async def _test_request_callback_signed_header(self):
+        # Arrange
+        service = QueueServiceClient(self._get_queue_url(), credential=self.account_key)
+        name = self.get_resource_name('cont')
+
+        # Act
+        try:
+            headers = {'x-ms-meta-hello': 'world'}
+            queue = await service.create_queue(name, headers=headers)
+
+            # Assert
+            metadata_cr = await queue.get_queue_properties()
+            metadata = metadata_cr.metadata
+            self.assertEqual(metadata, {'hello': 'world'})
+        finally:
+            await service.delete_queue(name)
+
+    def test_request_callback_signed_header(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_request_callback_signed_header())
+
+    async def _test_response_callback(self):
+        # Arrange
+        service = QueueServiceClient(self._get_queue_url(), credential=self.account_key)
+        name = self.get_resource_name('cont')
+        queue = service.get_queue_client(name)
+
+        # Act
+        def callback(response):
+            response.http_response.status_code = 200
+
+        # Assert
+        exists = await queue.get_queue_properties(raw_response_hook=callback)
+        self.assertTrue(exists)
+
+    def test_response_callback(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_response_callback())
+
+    async def _test_user_agent_default(self):
+        service = QueueServiceClient(self._get_queue_url(), credential=self.account_key)
+
+        def callback(response):
+            self.assertTrue('User-Agent' in response.http_request.headers)
+            self.assertEqual(
+                response.http_request.headers['User-Agent'],
+                "azsdk-python-storage-queue/12.0.0b1 Python/{} ({})".format(
+                    platform.python_version(),
+                    platform.platform()))
+
+        await service.get_service_properties(raw_response_hook=callback)
+
+    def test_user_agent_default(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_user_agent_default())
+
+    async def _test_user_agent_custom(self):
+        custom_app = "TestApp/v1.0"
+        service = QueueServiceClient(
+            self._get_queue_url(), credential=self.account_key, user_agent=custom_app)
+
+        def callback(response):
+            self.assertTrue('User-Agent' in response.http_request.headers)
+            self.assertEqual(
+                response.http_request.headers['User-Agent'],
+                "TestApp/v1.0 azsdk-python-storage-queue/12.0.0b1 Python/{} ({})".format(
+                    platform.python_version(),
+                    platform.platform()))
+
+        await service.get_service_properties(raw_response_hook=callback)
+
+        # Redefine the callback to validate the per-operation user_agent override.
+        def callback(response):
+            self.assertTrue('User-Agent' in response.http_request.headers)
+            self.assertEqual(
+                response.http_request.headers['User-Agent'],
+                "TestApp/v2.0 azsdk-python-storage-queue/12.0.0b1 Python/{} ({})".format(
+                    platform.python_version(),
+                    platform.platform()))
+
+        await service.get_service_properties(raw_response_hook=callback, user_agent="TestApp/v2.0")
+
+    def test_user_agent_custom(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_user_agent_custom())
+
+    async def _test_user_agent_append(self):
+        service = QueueServiceClient(self._get_queue_url(), credential=self.account_key)
+
+        def callback(response):
+            self.assertTrue('User-Agent' in response.http_request.headers)
+            self.assertEqual(
+                response.http_request.headers['User-Agent'],
+                "azsdk-python-storage-queue/12.0.0b1 Python/{} ({}) customer_user_agent".format(
+                    platform.python_version(),
+                    platform.platform()))
+
+        custom_headers = {'User-Agent': 'customer_user_agent'}
+        await service.get_service_properties(raw_response_hook=callback, headers=custom_headers)
+
+    def test_user_agent_append(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_user_agent_append())
+
+# ------------------------------------------------------------------------------
+if __name__ == '__main__':
+    unittest.main()
diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_encodings_async.py b/sdk/storage/azure-storage-queue/tests/test_queue_encodings_async.py
new file mode 100644
index 000000000000..b6ca8c836b67
--- /dev/null
+++ b/sdk/storage/azure-storage-queue/tests/test_queue_encodings_async.py
@@ -0,0 +1,255 @@
+# coding: utf-8
+
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import unittest
+import pytest
+import asyncio
+
+from azure.core.exceptions import HttpResponseError, DecodeError, ResourceExistsError
+from azure.storage.queue import (
+    TextBase64EncodePolicy,
+    TextBase64DecodePolicy,
+    BinaryBase64EncodePolicy,
+    BinaryBase64DecodePolicy,
+    TextXMLEncodePolicy,
+    TextXMLDecodePolicy,
+    NoEncodePolicy,
+    NoDecodePolicy
+)
+
+from azure.storage.queue.aio import (
+    QueueClient,
+    QueueServiceClient
+)
+
+from queuetestcase import (
+    QueueTestCase,
+    record,
+    TestMode
+)
+
+# ------------------------------------------------------------------------------
+TEST_QUEUE_PREFIX = 'mytestqueue'
+
+
+# ------------------------------------------------------------------------------
+
+class StorageQueueEncodingTestAsync(QueueTestCase):
+    def setUp(self):
+        super(StorageQueueEncodingTestAsync, self).setUp()
+
+        queue_url = self._get_queue_url()
+        credentials = self._get_shared_key_credential()
+        self.qsc = QueueServiceClient(account_url=queue_url, credential=credentials)
+        self.test_queues = []
+
+    def tearDown(self):
+        if not self.is_playback():
+            loop = asyncio.get_event_loop()
+            for queue in self.test_queues:
+                try:
+                    loop.run_until_complete(self.qsc.delete_queue(queue.queue_name))
+                except:
+                    pass
+        return super(StorageQueueEncodingTestAsync, self).tearDown()
+
+    # --Helpers-----------------------------------------------------------------
+    def _get_queue_reference(self, prefix=TEST_QUEUE_PREFIX):
+        queue_name = self.get_resource_name(prefix)
+        queue = self.qsc.get_queue_client(queue_name)
+        self.test_queues.append(queue)
+        return queue
+
+    async def _create_queue(self, prefix=TEST_QUEUE_PREFIX):
+        queue = self._get_queue_reference(prefix)
+        try:
+            created = await queue.create_queue()
+        except ResourceExistsError:
+            pass
+        return queue
+
+    async def _validate_encoding(self, queue, message):
+        # Arrange
+        try:
+            created = await queue.create_queue()
+        except ResourceExistsError:
+            pass
+
+        # Action.
+        await queue.enqueue_message(message)
+
+        # Asserts
+        dequeued = None
+        async for m in queue.receive_messages():
+            dequeued = m
+        self.assertEqual(message, dequeued.content)
+
+    # --------------------------------------------------------------------------
+
+    async def _test_message_text_xml(self):
+        # Arrange.
+        message = u'<message1>'
+        queue = self.qsc.get_queue_client(self.get_resource_name(TEST_QUEUE_PREFIX))
+
+        # Asserts
+        await self._validate_encoding(queue, message)
+
+    def test_message_text_xml(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_message_text_xml())
+
+    async def _test_message_text_xml_whitespace(self):
+        # Arrange.
+        message = u' mess\t age1\n'
+        queue = self.qsc.get_queue_client(self.get_resource_name(TEST_QUEUE_PREFIX))
+
+        # Asserts
+        await self._validate_encoding(queue, message)
+
+    def test_message_text_xml_whitespace(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_message_text_xml_whitespace())
+
+    async def _test_message_text_xml_invalid_chars(self):
+        # Action.
+        queue = self._get_queue_reference()
+        message = u'\u0001'
+
+        # Asserts
+        with self.assertRaises(HttpResponseError):
+            await queue.enqueue_message(message)
+
+    def test_message_text_xml_invalid_chars(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_message_text_xml_invalid_chars())
+
+    async def _test_message_text_base64(self):
+        # Arrange.
+        queue_url = self._get_queue_url()
+        credentials = self._get_shared_key_credential()
+        queue = QueueClient(
+            queue_url=queue_url,
+            queue=self.get_resource_name(TEST_QUEUE_PREFIX),
+            credential=credentials,
+            message_encode_policy=TextBase64EncodePolicy(),
+            message_decode_policy=TextBase64DecodePolicy())
+
+        message = u'\u0001'
+
+        # Asserts
+        await self._validate_encoding(queue, message)
+
+    def test_message_text_base64(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_message_text_base64())
+
+    async def _test_message_bytes_base64(self):
+        # Arrange.
+        queue_url = self._get_queue_url()
+        credentials = self._get_shared_key_credential()
+        queue = QueueClient(
+            queue_url=queue_url,
+            queue=self.get_resource_name(TEST_QUEUE_PREFIX),
+            credential=credentials,
+            message_encode_policy=BinaryBase64EncodePolicy(),
+            message_decode_policy=BinaryBase64DecodePolicy())
+
+        message = b'xyz'
+
+        # Asserts
+        await self._validate_encoding(queue, message)
+
+    def test_message_bytes_base64(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_message_bytes_base64())
+
+    async def _test_message_bytes_fails(self):
+        # Arrange
+        queue = self._get_queue_reference()
+
+        # Action.
+        with self.assertRaises(TypeError) as e:
+            message = b'xyz'
+            await queue.enqueue_message(message)
+
+        # Asserts
+        self.assertTrue(str(e.exception).startswith('Message content must be text'))
+
+    def test_message_bytes_fails(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_message_bytes_fails())
+
+    async def _test_message_text_fails(self):
+        # Arrange
+        queue_url = self._get_queue_url()
+        credentials = self._get_shared_key_credential()
+        queue = QueueClient(
+            queue_url=queue_url,
+            queue=self.get_resource_name(TEST_QUEUE_PREFIX),
+            credential=credentials,
+            message_encode_policy=BinaryBase64EncodePolicy(),
+            message_decode_policy=BinaryBase64DecodePolicy())
+
+        # Action.
+        with self.assertRaises(TypeError) as e:
+            message = u'xyz'
+            await queue.enqueue_message(message)
+
+        # Asserts
+        self.assertTrue(str(e.exception).startswith('Message content must be bytes'))
+
+    def test_message_text_fails(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_message_text_fails())
+
+    async def _test_message_base64_decode_fails(self):
+        # Arrange
+        queue_url = self._get_queue_url()
+        credentials = self._get_shared_key_credential()
+        queue = QueueClient(
+            queue_url=queue_url,
+            queue=self.get_resource_name(TEST_QUEUE_PREFIX),
+            credential=credentials,
+            message_encode_policy=TextXMLEncodePolicy(),
+            message_decode_policy=BinaryBase64DecodePolicy())
+        try:
+            await queue.create_queue()
+        except ResourceExistsError:
+            pass
+        message = u'xyz'
+        await queue.enqueue_message(message)
+
+        # Action.
+        with self.assertRaises(DecodeError) as e:
+            await queue.peek_messages()
+
+        # Asserts
+        self.assertNotEqual(-1, str(e.exception).find('Message content is not valid base 64'))
+
+    def test_message_base64_decode_fails(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_message_base64_decode_fails())
+
+# ------------------------------------------------------------------------------
+if __name__ == '__main__':
+    unittest.main()
diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_encryption.py b/sdk/storage/azure-storage-queue/tests/test_queue_encryption.py
index 4196991ebd55..5cf64c7b7834 100644
--- a/sdk/storage/azure-storage-queue/tests/test_queue_encryption.py
+++ b/sdk/storage/azure-storage-queue/tests/test_queue_encryption.py
@@ -4,6 +4,7 @@
 # license information.
 # --------------------------------------------------------------------------
 import unittest
+import six
 from base64 import (
     b64decode,
 )
@@ -19,7 +20,6 @@
 from cryptography.hazmat.primitives.padding import PKCS7
 
 from azure.core.exceptions import HttpResponseError, ResourceExistsError
-
 from azure.storage.queue._shared import decode_base64_to_bytes
 from azure.storage.queue._shared.encryption import (
     _ERROR_OBJECT_INVALID,
@@ -54,6 +54,10 @@
 
 # ------------------------------------------------------------------------------
+def _decode_base64_to_bytes(data):
+    if isinstance(data, six.text_type):
+        data = data.encode('utf-8')
+    return b64decode(data)
 
 class StorageQueueEncryptionTest(QueueTestCase):
     def setUp(self):
diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_encryption_async.py b/sdk/storage/azure-storage-queue/tests/test_queue_encryption_async.py
new file mode 100644
index 000000000000..2b7aaf46a257
--- /dev/null
+++ b/sdk/storage/azure-storage-queue/tests/test_queue_encryption_async.py
@@ -0,0 +1,608 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import unittest
+import pytest
+import asyncio
+import six
+from base64 import (
+    b64decode,
+)
+from json import (
+    loads,
+    dumps,
+)
+
+from cryptography.hazmat import backends
+from cryptography.hazmat.primitives.ciphers import Cipher
+from cryptography.hazmat.primitives.ciphers.algorithms import AES
+from cryptography.hazmat.primitives.ciphers.modes import CBC
+from cryptography.hazmat.primitives.padding import PKCS7
+
+from azure.core.exceptions import HttpResponseError, ResourceExistsError
+
+from azure.storage.queue._shared.encryption import (
+    _ERROR_OBJECT_INVALID,
+    _WrappedContentKey,
+    _EncryptionAgent,
+    _EncryptionData,
+)
+
+from azure.storage.queue import (
+    VERSION,
+    BinaryBase64EncodePolicy,
+    BinaryBase64DecodePolicy,
+    NoEncodePolicy,
+    NoDecodePolicy
+)
+
+from azure.storage.queue.aio import (
+    QueueServiceClient,
+    QueueClient
+)
+
+from encryption_test_helper import (
+    KeyWrapper,
+    KeyResolver,
+    RSAKeyWrapper,
+)
+from queuetestcase import (
+    QueueTestCase,
+    record,
+    TestMode,
+)
+
+# ------------------------------------------------------------------------------
+TEST_QUEUE_PREFIX = 'encryptionqueue'
+
+
+# ------------------------------------------------------------------------------
+
+def _decode_base64_to_bytes(data):
+    if isinstance(data, six.text_type):
+        data = data.encode('utf-8')
+    return b64decode(data)
+
+
+class StorageQueueEncryptionTestAsync(QueueTestCase):
+    def setUp(self):
+        super(StorageQueueEncryptionTestAsync, self).setUp()
+
+        queue_url = self._get_queue_url()
+        credentials = self._get_shared_key_credential()
+        self.qsc = QueueServiceClient(account_url=queue_url, credential=credentials)
+        self.test_queues = []
+
+    def tearDown(self):
+        if not self.is_playback():
+            loop = asyncio.get_event_loop()
+            for queue in self.test_queues:
+                try:
+                    loop.run_until_complete(self.qsc.delete_queue(queue.queue_name))
+                except:
+                    pass
+        return super(StorageQueueEncryptionTestAsync, self).tearDown()
+
+    # --Helpers-----------------------------------------------------------------
+    def _get_queue_reference(self, prefix=TEST_QUEUE_PREFIX):
+        queue_name = self.get_resource_name(prefix)
+        queue = self.qsc.get_queue_client(queue_name)
+        self.test_queues.append(queue)
+        return queue
+
+    async def _create_queue(self, prefix=TEST_QUEUE_PREFIX):
+        queue = self._get_queue_reference(prefix)
+        try:
+            created = await queue.create_queue()
+        except ResourceExistsError:
+            pass
+        return queue
+    # --------------------------------------------------------------------------
+
+    async def _test_get_messages_encrypted_kek(self):
+        # Arrange
+        self.qsc.key_encryption_key = KeyWrapper('key1')
+        queue = await self._create_queue()
+        await queue.enqueue_message(u'encrypted_message_2')
+
+        # Act
+        li = None
+        async for m in queue.receive_messages():
+            li = m
+
+        # Assert
+        self.assertEqual(li.content, u'encrypted_message_2')
+
+    def test_get_messages_encrypted_kek(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_get_messages_encrypted_kek())
+
+    async def _test_get_messages_encrypted_resolver(self):
+        # Arrange
+        self.qsc.key_encryption_key = KeyWrapper('key1')
+        queue = await self._create_queue()
+        await queue.enqueue_message(u'encrypted_message_2')
+        key_resolver = KeyResolver()
+        key_resolver.put_key(self.qsc.key_encryption_key)
+        queue.key_resolver_function = key_resolver.resolve_key
+        queue.key_encryption_key = None  # Ensure that the resolver is used
+
+        # Act
+        li = None
+        async for m in queue.receive_messages():
+            li = m
+
+        # Assert
+        self.assertEqual(li.content, u'encrypted_message_2')
+
+    def test_get_messages_encrypted_resolver(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_get_messages_encrypted_resolver())
+
+    async def _test_peek_messages_encrypted_kek(self):
+        # Arrange
+        self.qsc.key_encryption_key = KeyWrapper('key1')
+        queue = await self._create_queue()
+        await queue.enqueue_message(u'encrypted_message_3')
+
+        # Act
+        li = await queue.peek_messages()
+
+        # Assert
+        self.assertEqual(li[0].content, u'encrypted_message_3')
+
+    def test_peek_messages_encrypted_kek(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_peek_messages_encrypted_kek())
+
+    async def _test_peek_messages_encrypted_resolver(self):
+        # Arrange
+        self.qsc.key_encryption_key = KeyWrapper('key1')
+        queue = await self._create_queue()
+        await queue.enqueue_message(u'encrypted_message_4')
+        key_resolver = KeyResolver()
+        key_resolver.put_key(self.qsc.key_encryption_key)
+        queue.key_resolver_function = key_resolver.resolve_key
+        queue.key_encryption_key = None  # Ensure that the resolver is used
+
+        # Act
+        li = await queue.peek_messages()
+
+        # Assert
+        self.assertEqual(li[0].content, u'encrypted_message_4')
+
+    def test_peek_messages_encrypted_resolver(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_peek_messages_encrypted_resolver())
+
+    async def _test_peek_messages_encrypted_kek_RSA(self):
+
+        # We can only generate random RSA keys, so this must be run live or
+        # the playback test will fail due to a change in kek values.
+        if TestMode.need_recording_file(self.test_mode):
+            return
+
+        # Arrange
+        self.qsc.key_encryption_key = RSAKeyWrapper('key2')
+        queue = await self._create_queue()
+        await queue.enqueue_message(u'encrypted_message_3')
+
+        # Act
+        li = await queue.peek_messages()
+
+        # Assert
+        self.assertEqual(li[0].content, u'encrypted_message_3')
+
+    def test_peek_messages_encrypted_kek_RSA(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_peek_messages_encrypted_kek_RSA())
+
+    async def _test_update_encrypted_message(self):
+        # TODO: Recording doesn't work
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        # Arrange
+        queue = await self._create_queue()
+        queue.key_encryption_key = KeyWrapper('key1')
+        await queue.enqueue_message(u'Update Me')
+
+        messages = []
+        async for m in queue.receive_messages():
+            messages.append(m)
+        list_result1 = messages[0]
+        list_result1.content = u'Updated'
+
+        # Act
+        message = await queue.update_message(list_result1)
+        messages = []
+        async for m in queue.receive_messages():
+            messages.append(m)
+        list_result2 = messages[0]
+
+        # Assert
+        self.assertEqual(u'Updated', list_result2.content)
+
+    def test_update_encrypted_message(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_update_encrypted_message())
+
+    async def _test_update_encrypted_binary_message(self):
+        # Arrange
+        queue = await self._create_queue()
+        queue.key_encryption_key = KeyWrapper('key1')
+        queue._config.message_encode_policy = BinaryBase64EncodePolicy()
+        queue._config.message_decode_policy = BinaryBase64DecodePolicy()
+
+        binary_message = self.get_random_bytes(100)
+        await queue.enqueue_message(binary_message)
+        messages = []
+        async for m in queue.receive_messages():
+            messages.append(m)
+        list_result1 = messages[0]
+
+        # Act
+        binary_message = self.get_random_bytes(100)
+        list_result1.content = binary_message
+        await queue.update_message(list_result1)
+
+        messages = []
+        async for m in queue.receive_messages():
+            messages.append(m)
+        list_result2 = messages[0]
+
+        # Assert
+        self.assertEqual(binary_message, list_result2.content)
+
+    def test_update_encrypted_binary_message(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_update_encrypted_binary_message())
+
+    async def _test_update_encrypted_raw_text_message(self):
+        # TODO: Recording doesn't work
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        # Arrange
+        queue = await self._create_queue()
+        queue.key_encryption_key = KeyWrapper('key1')
+        queue._config.message_encode_policy = NoEncodePolicy()
+        queue._config.message_decode_policy = NoDecodePolicy()
+
+        raw_text = u'Update Me'
+        await queue.enqueue_message(raw_text)
+        messages = []
+        async for m in queue.receive_messages():
+            messages.append(m)
+        list_result1 = messages[0]
+
+        # Act
+        raw_text = u'Updated'
+        list_result1.content = raw_text
+        await queue.update_message(list_result1)
+
+        messages = []
+        async for m in queue.receive_messages():
+            messages.append(m)
+        list_result2 = messages[0]
+
+        # Assert
+        self.assertEqual(raw_text, list_result2.content)
+
+    def test_update_encrypted_raw_text_message(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_update_encrypted_raw_text_message())
+
+    async def _test_update_encrypted_json_message(self):
+        # TODO: Recording doesn't work
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        # Arrange
+        queue = await self._create_queue()
+        queue.key_encryption_key = KeyWrapper('key1')
+        queue._config.message_encode_policy = NoEncodePolicy()
+        queue._config.message_decode_policy = NoDecodePolicy()
+
+        message_dict = {'val1': 1, 'val2': '2'}
+        json_text = dumps(message_dict)
+        await queue.enqueue_message(json_text)
+        messages = []
+        async for m in queue.receive_messages():
+            messages.append(m)
+        list_result1 = messages[0]
+
+        # Act
+        message_dict['val1'] = 0
+        message_dict['val2'] = 'updated'
+        json_text = dumps(message_dict)
+        list_result1.content = json_text
+        await queue.update_message(list_result1)
+
+        messages = []
+        async for m in queue.receive_messages():
+            messages.append(m)
+        list_result2 = messages[0]
+
+        # Assert
+        self.assertEqual(message_dict, loads(list_result2.content))
+
+    def test_update_encrypted_json_message(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_update_encrypted_json_message())
+
+    async def _test_invalid_value_kek_wrap(self):
+        # Arrange
+        queue = await self._create_queue()
+        queue.key_encryption_key = KeyWrapper('key1')
+        queue.key_encryption_key.get_kid = None
+
+        with self.assertRaises(AttributeError) as e:
+            await queue.enqueue_message(u'message')
+
+        self.assertEqual(str(e.exception), _ERROR_OBJECT_INVALID.format('key encryption key', 'get_kid'))
+
+        queue.key_encryption_key = KeyWrapper('key1')
+        queue.key_encryption_key.get_key_wrap_algorithm = None
+        with self.assertRaises(AttributeError):
+            await queue.enqueue_message(u'message')
+
+        queue.key_encryption_key = KeyWrapper('key1')
+        queue.key_encryption_key.wrap_key = None
+        with self.assertRaises(AttributeError):
+            await queue.enqueue_message(u'message')
+
+    def test_invalid_value_kek_wrap(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_invalid_value_kek_wrap())
+
+    async def _test_missing_attribute_kek_wrap(self):
+        # Arrange
+        queue = await self._create_queue()
+
+        valid_key = KeyWrapper('key1')
+
+        # Act
+        invalid_key_1 = lambda: None  # functions are objects, so this effectively creates an empty object
+        invalid_key_1.get_key_wrap_algorithm = valid_key.get_key_wrap_algorithm
+        invalid_key_1.get_kid = valid_key.get_kid
+        # No attribute wrap_key
+        queue.key_encryption_key = invalid_key_1
+        with self.assertRaises(AttributeError):
+            await queue.enqueue_message(u'message')
+
+        invalid_key_2 = lambda: None  # functions are objects, so this effectively creates an empty object
+        invalid_key_2.wrap_key = valid_key.wrap_key
+        invalid_key_2.get_kid = valid_key.get_kid
+        # No attribute get_key_wrap_algorithm
+        queue.key_encryption_key = invalid_key_2
+        with self.assertRaises(AttributeError):
+            await queue.enqueue_message(u'message')
+
+        invalid_key_3 = lambda: None  # functions are objects, so this effectively creates an empty object
+        invalid_key_3.get_key_wrap_algorithm = valid_key.get_key_wrap_algorithm
+        invalid_key_3.wrap_key = valid_key.wrap_key
+        # No attribute get_kid
+        queue.key_encryption_key = invalid_key_3
+        with self.assertRaises(AttributeError):
+            await queue.enqueue_message(u'message')
+
+    def test_missing_attribute_kek_wrap(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_missing_attribute_kek_wrap())
+
+    async def _test_invalid_value_kek_unwrap(self):
+        # Arrange
+        queue = await self._create_queue()
+        queue.key_encryption_key = KeyWrapper('key1')
+        await queue.enqueue_message(u'message')
+
+        # Act
+        queue.key_encryption_key.unwrap_key = None
+        with self.assertRaises(HttpResponseError):
+            await queue.peek_messages()
+
+        queue.key_encryption_key.get_kid = None
+        with self.assertRaises(HttpResponseError):
+            await queue.peek_messages()
+
+    def test_invalid_value_kek_unwrap(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_invalid_value_kek_unwrap())
+
+    async def _test_missing_attribute_kek_unwrap(self):
+        # Arrange
+        queue = await self._create_queue()
+        queue.key_encryption_key = KeyWrapper('key1')
+        await queue.enqueue_message(u'message')
+
+        # Act
+        valid_key = KeyWrapper('key1')
+        invalid_key_1 = lambda: None  # functions are objects, so this effectively creates an empty object
+        invalid_key_1.unwrap_key = valid_key.unwrap_key
+        # No attribute get_kid
+        queue.key_encryption_key = invalid_key_1
+        with self.assertRaises(HttpResponseError) as e:
+            await queue.peek_messages()
+
+        self.assertEqual(str(e.exception), "Decryption failed.")
+
+        invalid_key_2 = lambda: None  # functions are objects, so this effectively creates an empty object
+        invalid_key_2.get_kid = valid_key.get_kid
+        # No attribute unwrap_key
+        queue.key_encryption_key = invalid_key_2
+        with self.assertRaises(HttpResponseError):
+            await queue.peek_messages()
+
+    def test_missing_attribute_kek_unwrap(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_missing_attribute_kek_unwrap())
+
+    async def _test_validate_encryption(self):
+        # Arrange
+        queue = await self._create_queue()
+        kek = KeyWrapper('key1')
+        queue.key_encryption_key = kek
+        await queue.enqueue_message(u'message')
+
+        # Act
+        queue.key_encryption_key = None  # Message will not be decrypted
+        li = await queue.peek_messages()
+        message = li[0].content
+        message = loads(message)
+
+        encryption_data = message['EncryptionData']
+
+        wrapped_content_key = encryption_data['WrappedContentKey']
+        wrapped_content_key = _WrappedContentKey(
+            wrapped_content_key['Algorithm'],
+            b64decode(wrapped_content_key['EncryptedKey'].encode(encoding='utf-8')),
+            wrapped_content_key['KeyId'])
+
+        encryption_agent = encryption_data['EncryptionAgent']
+        encryption_agent = _EncryptionAgent(
+            encryption_agent['EncryptionAlgorithm'],
+            encryption_agent['Protocol'])
+
+        encryption_data = _EncryptionData(
+            b64decode(encryption_data['ContentEncryptionIV'].encode(encoding='utf-8')),
+            encryption_agent,
+            wrapped_content_key,
+            {'EncryptionLibrary': VERSION})
+
+        message = message['EncryptedMessageContents']
+        content_encryption_key = kek.unwrap_key(
+            encryption_data.wrapped_content_key.encrypted_key,
+            encryption_data.wrapped_content_key.algorithm)
+
+        # Create decryption cipher
+        backend = backends.default_backend()
+        algorithm = AES(content_encryption_key)
+        mode = CBC(encryption_data.content_encryption_IV)
+        cipher = Cipher(algorithm, mode, backend)
+
+        # decode and decrypt data
+        decrypted_data = _decode_base64_to_bytes(message)
+        decryptor = cipher.decryptor()
+        decrypted_data = (decryptor.update(decrypted_data) + decryptor.finalize())
+
+        # unpad data
+        unpadder = PKCS7(128).unpadder()
+        decrypted_data = (unpadder.update(decrypted_data) + unpadder.finalize())
+
+        decrypted_data = decrypted_data.decode(encoding='utf-8')
+
+        # Assert
+        self.assertEqual(decrypted_data, u'message')
+
+    def test_validate_encryption(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_validate_encryption())
+
+    async def _test_put_with_strict_mode(self):
+        # Arrange
+        queue = await self._create_queue()
+        kek = KeyWrapper('key1')
+        queue.key_encryption_key = kek
+        queue.require_encryption = True
+
+        await queue.enqueue_message(u'message')
+        queue.key_encryption_key = None
+
+        # Assert
+        with self.assertRaises(ValueError) as e:
+            await queue.enqueue_message(u'message')
+
+        self.assertEqual(str(e.exception), "Encryption required but no key was provided.")
+
+    def test_put_with_strict_mode(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_put_with_strict_mode())
+
+    async def _test_get_with_strict_mode(self):
+        # Arrange
+        queue = await self._create_queue()
+        await queue.enqueue_message(u'message')
+
+        queue.require_encryption = True
+        queue.key_encryption_key = KeyWrapper('key1')
+        with self.assertRaises(ValueError) as e:
+            messages = []
+            async for m in queue.receive_messages():
+                messages.append(m)
+            _ = messages[0]
+        self.assertEqual(str(e.exception), 'Message was not encrypted.')
+
+    def test_get_with_strict_mode(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_get_with_strict_mode())
+
+    async def _test_encryption_add_encrypted_64k_message(self):
+        # Arrange
+        queue = await self._create_queue()
+        message = u'a' * 1024 * 64
+
+        # Act
+        await queue.enqueue_message(message)
+
+        # Assert
+        queue.key_encryption_key = KeyWrapper('key1')
+        with self.assertRaises(HttpResponseError):
+            await queue.enqueue_message(message)
+
+    def test_encryption_add_encrypted_64k_message(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_encryption_add_encrypted_64k_message())
+
+    async def _test_encryption_nonmatching_kid(self):
+        # Arrange
+        queue = await self._create_queue()
+        queue.key_encryption_key = KeyWrapper('key1')
+        await queue.enqueue_message(u'message')
+
+        # Act
+        queue.key_encryption_key.kid = 'Invalid'
+
+        # Assert
+        with self.assertRaises(HttpResponseError) as e:
+            messages = []
+            async for m in queue.receive_messages():
+                messages.append(m)
+
+        self.assertEqual(str(e.exception), "Decryption failed.")
+
+    def test_encryption_nonmatching_kid(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_encryption_nonmatching_kid())
+
+
+# ------------------------------------------------------------------------------
+if __name__ == '__main__':
+    unittest.main()
diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_samples_authentication_async.py b/sdk/storage/azure-storage-queue/tests/test_queue_samples_authentication_async.py
new file mode 100644
index 000000000000..5c30bc073e5a
--- /dev/null
+++ b/sdk/storage/azure-storage-queue/tests/test_queue_samples_authentication_async.py
@@ -0,0 +1,121 @@
+# coding: utf-8
+
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+from datetime import datetime, timedelta
+import pytest
+import asyncio
+
+try:
+    import settings_real as settings
+except ImportError:
+    import queue_settings_fake as settings
+
+from queuetestcase import (
+    QueueTestCase,
+    TestMode,
+    record
+)
+
+
+class TestQueueAuthSamplesAsync(QueueTestCase):
+    url = "{}://{}.queue.core.windows.net".format(
+        settings.PROTOCOL,
+        settings.STORAGE_ACCOUNT_NAME
+    )
+
+    connection_string = settings.CONNECTION_STRING
+    shared_access_key = settings.STORAGE_ACCOUNT_KEY
+    active_directory_application_id = settings.ACTIVE_DIRECTORY_APPLICATION_ID
+    active_directory_application_secret = settings.ACTIVE_DIRECTORY_APPLICATION_SECRET
+    active_directory_tenant_id = settings.ACTIVE_DIRECTORY_TENANT_ID
+
+    async def _test_auth_connection_string(self):
+        # Instantiate a QueueServiceClient using a connection string
+        # [START auth_from_connection_string]
+        from azure.storage.queue.aio import QueueServiceClient
+        queue_service = QueueServiceClient.from_connection_string(self.connection_string)
+        # [END auth_from_connection_string]
+
+        # Get information for the Queue Service
+        properties = await queue_service.get_service_properties()
+
+        assert properties is not None
+
+    def test_auth_connection_string(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_auth_connection_string())
+
+    async def _test_auth_shared_key(self):
+        # Instantiate a QueueServiceClient using a shared access key
+        # [START create_queue_service_client]
+        from azure.storage.queue.aio import QueueServiceClient
+        queue_service = QueueServiceClient(account_url=self.url, credential=self.shared_access_key)
+        # [END create_queue_service_client]
+
+        # Get information for the Queue Service
+        properties = await queue_service.get_service_properties()
+
+        assert properties is not None
+
+    def test_auth_shared_key(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_auth_shared_key())
+
+    async def _test_auth_active_directory(self):
+        pytest.skip('pending azure identity')
+
+        # Get a token credential for authentication
+        from azure.identity import ClientSecretCredential
+        token_credential = ClientSecretCredential(
+            self.active_directory_application_id,
+            self.active_directory_application_secret,
+            self.active_directory_tenant_id
+        )
+
+        # Instantiate a QueueServiceClient using a token credential
+        from azure.storage.queue.aio import QueueServiceClient
+        queue_service = QueueServiceClient(account_url=self.url, credential=token_credential)
+
+        # Get information for the Queue Service
+        properties = await queue_service.get_service_properties()
+
+        assert properties is not None
+
+    def test_auth_active_directory(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_auth_active_directory())
+
+    async def _test_auth_shared_access_signature(self):
+        # SAS URL is calculated from storage key, so this test runs live only
+        if TestMode.need_recording_file(self.test_mode):
+            return
+
+        # Instantiate a QueueServiceClient using a connection string
+        from azure.storage.queue.aio import QueueServiceClient
+        queue_service = QueueServiceClient.from_connection_string(self.connection_string)
+
+        # Create a SAS token to use for authentication of a client
+        sas_token = queue_service.generate_shared_access_signature(
+            resource_types="object",
+            permission="read",
+            expiry=datetime.utcnow() + timedelta(hours=1)
+        )
+
+        assert sas_token is not None
+
+    def test_auth_shared_access_signature(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_auth_shared_access_signature())
diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_samples_hello_world_async.py b/sdk/storage/azure-storage-queue/tests/test_queue_samples_hello_world_async.py
new file mode 100644
index 000000000000..26ea87514a45
--- /dev/null
+++ b/sdk/storage/azure-storage-queue/tests/test_queue_samples_hello_world_async.py
@@ -0,0 +1,42 @@
+# coding: utf-8
+
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import pytest
+import asyncio
+
+try:
+    import settings_real as settings
+except ImportError:
+    import queue_settings_fake as settings
+
+from queuetestcase import (
+    QueueTestCase,
+    record,
+    TestMode
+)
+
+
+class TestQueueHelloWorldSamplesAsync(QueueTestCase):
+
+    connection_string = settings.CONNECTION_STRING
+
+    async def _test_create_client_with_connection_string(self):
+        # Instantiate the QueueServiceClient from a connection string
+        from azure.storage.queue.aio import QueueServiceClient
+        queue_service = QueueServiceClient.from_connection_string(self.connection_string)
+
+        # Get queue service properties
+        properties = await queue_service.get_service_properties()
+
+        assert properties is not None
+
+    def test_create_client_with_connection_string(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_create_client_with_connection_string())
diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_service_properties_async.py b/sdk/storage/azure-storage-queue/tests/test_queue_service_properties_async.py
new file mode 100644
index 000000000000..2d85d9cd13a6
--- /dev/null
+++ b/sdk/storage/azure-storage-queue/tests/test_queue_service_properties_async.py
@@ -0,0 +1,272 @@
+# coding: utf-8
+
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import unittest
+import pytest
+import asyncio
+
+from azure.core.exceptions import HttpResponseError
+
+from azure.storage.queue.aio import (
+    QueueServiceClient,
+    QueueClient,
+    Logging,
+    Metrics,
+    CorsRule,
+    RetentionPolicy
+)
+
+from queuetestcase import (
+    QueueTestCase,
+    record,
+    not_for_emulator,
+    TestMode
+)
+
+
+# ------------------------------------------------------------------------------
+
+
+class QueueServicePropertiesTest(QueueTestCase):
+    def setUp(self):
+        super(QueueServicePropertiesTest, self).setUp()
+
+        url = self._get_queue_url()
+        credential = self._get_shared_key_credential()
+        self.qsc = QueueServiceClient(url, credential=credential)
+
+    # --Helpers-----------------------------------------------------------------
+    def _assert_properties_default(self, prop):
+        self.assertIsNotNone(prop)
+
+        self._assert_logging_equal(prop.logging, Logging())
+        self._assert_metrics_equal(prop.hour_metrics, Metrics())
+        self._assert_metrics_equal(prop.minute_metrics, Metrics())
+        self._assert_cors_equal(prop.cors, list())
+
+    def _assert_logging_equal(self, log1, log2):
+        if log1 is None or log2 is None:
+            self.assertEqual(log1, log2)
+            return
+
+        self.assertEqual(log1.version, log2.version)
+        self.assertEqual(log1.read, log2.read)
+        self.assertEqual(log1.write, log2.write)
+        self.assertEqual(log1.delete, log2.delete)
+        self._assert_retention_equal(log1.retention_policy, log2.retention_policy)
+
+    def _assert_delete_retention_policy_equal(self, policy1, policy2):
+        if policy1 is None or policy2 is None:
+            self.assertEqual(policy1, policy2)
+            return
+
+        self.assertEqual(policy1.enabled, policy2.enabled)
+        self.assertEqual(policy1.days, policy2.days)
+
+    def _assert_static_website_equal(self, prop1, prop2):
+        if prop1 is None or prop2 is None:
+            self.assertEqual(prop1, prop2)
+            return
+
+        self.assertEqual(prop1.enabled, prop2.enabled)
+        self.assertEqual(prop1.index_document, prop2.index_document)
+        self.assertEqual(prop1.error_document404_path, prop2.error_document404_path)
+
+    def _assert_delete_retention_policy_not_equal(self, policy1, policy2):
+        if policy1 is None or policy2 is None:
+            self.assertNotEqual(policy1, policy2)
+            return
+
+        self.assertFalse(policy1.enabled == policy2.enabled
+                         and policy1.days == policy2.days)
+
+    def _assert_metrics_equal(self, metrics1, metrics2):
+        if metrics1 is None or metrics2 is None:
+            self.assertEqual(metrics1, metrics2)
+            return
+
+        self.assertEqual(metrics1.version, metrics2.version)
+        self.assertEqual(metrics1.enabled, metrics2.enabled)
+        self.assertEqual(metrics1.include_apis, metrics2.include_apis)
+        self._assert_retention_equal(metrics1.retention_policy, metrics2.retention_policy)
+
+    def _assert_cors_equal(self, cors1, cors2):
+        if cors1 is None or cors2 is None:
+            self.assertEqual(cors1, cors2)
+            return
+
+        self.assertEqual(len(cors1), len(cors2))
+
+        for i in range(0, len(cors1)):
+            rule1 = cors1[i]
+            rule2 = cors2[i]
+            self.assertEqual(len(rule1.allowed_origins), len(rule2.allowed_origins))
+            self.assertEqual(len(rule1.allowed_methods), len(rule2.allowed_methods))
+            self.assertEqual(rule1.max_age_in_seconds, rule2.max_age_in_seconds)
+            self.assertEqual(len(rule1.exposed_headers), len(rule2.exposed_headers))
+            self.assertEqual(len(rule1.allowed_headers), len(rule2.allowed_headers))
+
+    def _assert_retention_equal(self, ret1, ret2):
+        self.assertEqual(ret1.enabled, ret2.enabled)
+        self.assertEqual(ret1.days, ret2.days)
+
+    # --Test cases per service ---------------------------------------
+
+    async def _test_queue_service_properties(self):
+        # Arrange
+
+        # Act
+        resp = await self.qsc.set_service_properties(
+            logging=Logging(),
+            hour_metrics=Metrics(),
+            minute_metrics=Metrics(),
+            cors=list())
+
+        # Assert
+        self.assertIsNone(resp)
+        props = await self.qsc.get_service_properties()
+        self._assert_properties_default(props)
+
+    def test_queue_service_properties(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_queue_service_properties())
+
+    # --Test cases per feature ---------------------------------------
+
+    async def _test_set_logging(self):
+        # Arrange
+        logging = Logging(read=True, write=True, delete=True, retention_policy=RetentionPolicy(enabled=True, days=5))
+
+        # Act
+        await self.qsc.set_service_properties(logging=logging)
+
+        # Assert
+        received_props = await self.qsc.get_service_properties()
+        self._assert_logging_equal(received_props.logging, logging)
+
+    def test_set_logging(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_set_logging())
+
+    async def _test_set_hour_metrics(self):
+        # Arrange
+        hour_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5))
+
+        # Act
+        await self.qsc.set_service_properties(hour_metrics=hour_metrics)
+
+        # Assert
+        received_props = await self.qsc.get_service_properties()
+        self._assert_metrics_equal(received_props.hour_metrics, hour_metrics)
+
+    def test_set_hour_metrics(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_set_hour_metrics())
+
+    async def _test_set_minute_metrics(self):
+        # Arrange
+        minute_metrics = Metrics(enabled=True, include_apis=True,
+                                 retention_policy=RetentionPolicy(enabled=True, days=5))
+
+        # Act
+        await self.qsc.set_service_properties(minute_metrics=minute_metrics)
+
+        # Assert
+        received_props = await self.qsc.get_service_properties()
+        self._assert_metrics_equal(received_props.minute_metrics, minute_metrics)
+
+    def test_set_minute_metrics(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_set_minute_metrics())
+
+    async def _test_set_cors(self):
+        # Arrange
+        cors_rule1 = CorsRule(['www.xyz.com'], ['GET'])
+
+        allowed_origins = ['www.xyz.com', "www.ab.com", "www.bc.com"]
+        allowed_methods = ['GET', 'PUT']
+        max_age_in_seconds = 500
+        exposed_headers = ["x-ms-meta-data*", "x-ms-meta-source*", "x-ms-meta-abc", "x-ms-meta-bcd"]
+        allowed_headers = ["x-ms-meta-data*", "x-ms-meta-target*", "x-ms-meta-xyz", "x-ms-meta-foo"]
+        cors_rule2 = CorsRule(
+            allowed_origins,
+            allowed_methods,
+            max_age_in_seconds=max_age_in_seconds,
+            exposed_headers=exposed_headers,
+            allowed_headers=allowed_headers)
+
+        cors = [cors_rule1, cors_rule2]
+
+        # Act
+        await self.qsc.set_service_properties(cors=cors)
+
+        # Assert
+        received_props = await self.qsc.get_service_properties()
+        self._assert_cors_equal(received_props.cors, cors)
+
+    def test_set_cors(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_set_cors())
+
+    # --Test cases for errors ---------------------------------------
+    async def _test_retention_no_days(self):
+        # Assert
+        self.assertRaises(ValueError,
+                          RetentionPolicy,
+                          True, None)
+
+    def test_retention_no_days(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_retention_no_days())
+
+    async def _test_too_many_cors_rules(self):
+        # Arrange
+        cors = []
+        for _ in range(0, 6):
+            cors.append(CorsRule(['www.xyz.com'], ['GET']))
+
+        # Assert
+        with self.assertRaises(HttpResponseError):
+            await self.qsc.set_service_properties(cors=cors)
+
+    def test_too_many_cors_rules(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_too_many_cors_rules())
+
+    async def _test_retention_too_long(self):
+        # Arrange
+        minute_metrics = Metrics(enabled=True, include_apis=True,
+                                 retention_policy=RetentionPolicy(enabled=True, days=366))
+
+        # Assert
+        with self.assertRaises(HttpResponseError):
+            await self.qsc.set_service_properties(minute_metrics=minute_metrics)
+
+    def test_retention_too_long(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_retention_too_long())
+
+# ------------------------------------------------------------------------------
+if __name__ == '__main__':
+    unittest.main()
diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_service_stats_async.py b/sdk/storage/azure-storage-queue/tests/test_queue_service_stats_async.py
new file mode 100644
index 000000000000..3a66d1a98241
--- /dev/null
+++ b/sdk/storage/azure-storage-queue/tests/test_queue_service_stats_async.py
@@ -0,0 +1,83 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import unittest
+import asyncio
+
+from azure.storage.queue.aio import QueueServiceClient
+
+from queuetestcase import (
+    QueueTestCase,
+    record,
+    TestMode
+)
+
+SERVICE_UNAVAILABLE_RESP_BODY = '<?xml version="1.0" encoding="utf-8"?><StorageServiceStats><GeoReplication>' \
+                                '<Status>unavailable</Status><LastSyncTime></LastSyncTime>' \
+                                '</GeoReplication></StorageServiceStats>'
+
+
+# --Test Class -----------------------------------------------------------------
+class QueueServiceStatsTestAsync(QueueTestCase):
+    # --Helpers-----------------------------------------------------------------
+    def _assert_stats_default(self, stats):
+        self.assertIsNotNone(stats)
+        self.assertIsNotNone(stats.geo_replication)
+
+        self.assertEqual(stats.geo_replication.status, 'live')
+        self.assertIsNotNone(stats.geo_replication.last_sync_time)
+
+    def _assert_stats_unavailable(self, stats):
+        self.assertIsNotNone(stats)
+        self.assertIsNotNone(stats.geo_replication)
+
+        self.assertEqual(stats.geo_replication.status, 'unavailable')
+        self.assertIsNone(stats.geo_replication.last_sync_time)
+
+    @staticmethod
+    def override_response_body_with_unavailable_status(response):
+        response.http_response.text = lambda: SERVICE_UNAVAILABLE_RESP_BODY
+
+    # --Test cases per service ---------------------------------------
+
+    async def _test_queue_service_stats_f(self):
+        # Arrange
+        url = self._get_queue_url()
+        credential = self._get_shared_key_credential()
+        qsc = QueueServiceClient(url, credential=credential)
+
+        # Act
+        stats = await qsc.get_service_stats()
+
+        # Assert
+        self._assert_stats_default(stats)
+
+    def test_queue_service_stats_f(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_queue_service_stats_f())
+
+    async def _test_queue_service_stats_when_unavailable(self):
+        # Arrange
+        url = self._get_queue_url()
+        credential = self._get_shared_key_credential()
+        qsc = QueueServiceClient(url, credential=credential)
+
+        # Act
+        stats = await qsc.get_service_stats(
+            raw_response_hook=self.override_response_body_with_unavailable_status)
+
+        # Assert
+        self._assert_stats_unavailable(stats)
+
+    def test_queue_service_stats_when_unavailable(self):
+        if TestMode.need_recording_file(self.test_mode):
+            return
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self._test_queue_service_stats_when_unavailable())
+
+# ------------------------------------------------------------------------------
+if __name__ == '__main__':
+    unittest.main()