commit 3db3e75edbc4ff39aa2fde72a6c4c8ef7433768b
annatisch committed Jul 19, 2019

    Latest shared code
@@ -100,6 +100,9 @@ def _add_authorization_header(self, request, string_to_sign):
raise _wrap_exception(ex, AzureSigningError)

def on_request(self, request, **kwargs):
+        if not 'content-type' in request.http_request.headers:
+            request.http_request.headers['content-type'] = 'application/xml; charset=utf-8'
+
string_to_sign = \
self._get_verb(request) + \
self._get_headers(
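This hunk makes the shared key policy default the `content-type` header before the string-to-sign is assembled, so the signature always covers the header that actually goes on the wire. A minimal sketch of the idea, with illustrative names only (not the SDK's real signing code):

```python
# Hypothetical sketch: Shared Key authorization signs selected headers,
# including content-type, so any default must be applied before signing.
def build_string_to_sign(verb, headers):
    return "\n".join([
        verb,
        headers.get("content-type", ""),  # must equal the header that is sent
        headers.get("x-ms-date", ""),
    ])

headers = {"x-ms-date": "Fri, 19 Jul 2019 00:00:00 GMT"}
headers.setdefault("content-type", "application/xml; charset=utf-8")  # default first
string_to_sign = build_string_to_sign("PUT", headers)
```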
@@ -20,21 +20,15 @@
BearerTokenCredentialPolicy,
AsyncRedirectPolicy)

-from .constants import STORAGE_OAUTH_SCOPE
+from .constants import STORAGE_OAUTH_SCOPE, DEFAULT_SOCKET_TIMEOUT
from .authentication import SharedKeyCredentialPolicy
-from .base_client import (
-    StorageAccountHostsMixin,
-    parse_query,
-    is_credential_sastoken,
-    format_shared_key_credential,
-    create_configuration,
-    parse_connection_str)
+from .base_client import create_configuration
from .policies import (
    StorageContentValidation,
    StorageRequestHook,
    StorageHosts,
    QueueMessagePolicy)
-from .policies_async import ExponentialRetry, AsyncStorageResponseHook
+from .policies_async import AsyncStorageResponseHook


_LOGGER = logging.getLogger(__name__)
@@ -65,6 +59,8 @@ def _create_pipeline(self, credential, **kwargs):
elif credential is not None:
raise TypeError("Unsupported credential: {}".format(credential))

+        if 'connection_timeout' not in kwargs:
+            kwargs['connection_timeout'] = DEFAULT_SOCKET_TIMEOUT[0]
config = kwargs.get('_configuration') or create_configuration(**kwargs)
if kwargs.get('_pipeline'):
return config, kwargs['_pipeline']
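The async base client now falls back to `DEFAULT_SOCKET_TIMEOUT[0]` when the caller does not pass `connection_timeout`, before the configuration is built. The same pattern in isolation; the tuple values below are placeholders, not the SDK's actual constants:

```python
# Assumed shape: DEFAULT_SOCKET_TIMEOUT is a (connect, read) timeout pair,
# which is what the [0] index in the diff suggests.
DEFAULT_SOCKET_TIMEOUT = (20, 2000)  # placeholder values

def create_pipeline(**kwargs):
    # Default the connect timeout only when the caller did not supply one.
    kwargs.setdefault('connection_timeout', DEFAULT_SOCKET_TIMEOUT[0])
    return kwargs

assert create_pipeline()['connection_timeout'] == 20
assert create_pipeline(connection_timeout=5)['connection_timeout'] == 5
```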
@@ -10,7 +10,6 @@

from azure.core.exceptions import HttpResponseError

-from .models import ModifiedAccessConditions
from .request_handlers import validate_and_format_range_headers
from .response_handlers import process_storage_error, parse_length_from_content_range
from .encryption import decrypt_blob
@@ -212,7 +211,7 @@ def _write_to_stream(self, chunk_data, chunk_start):
self.stream.write(chunk_data)


-class StorageStreamDownloader(object):
+class StorageStreamDownloader(object):  # pylint: disable=too-many-instance-attributes
"""A streaming object to download from Azure Storage.

The stream downloader can be iterated, or used to download to an open file or stream
@@ -294,14 +293,14 @@ def __iter__(self):
# Use the length unless it is over the end of the file
data_end = min(self.file_size, self.length + 1)

-        downloader = SequentialBlobChunkDownloader(
+        downloader = SequentialChunkDownloader(
            service=self.service,
            total_size=self.download_size,
            chunk_size=self.config.max_chunk_get_size,
            current_progress=self.first_get_size,
            start_range=self.initial_range[1] + 1,  # start where the first download ended
            end_range=data_end,
-            stream=stream,
+            stream=None,
            validate_content=self.validate_content,
            encryption_options=self.encryption_options,
            use_location=self.location_mode,
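In the synchronous downloader, `__iter__` now builds the shared `SequentialChunkDownloader` with `stream=None`, since iteration yields chunks to the caller instead of writing to a stream. A rough sketch of how a sequential downloader can walk the remaining byte ranges; this is assumed behavior for illustration, not the SDK's exact implementation:

```python
def chunk_offsets(start_range, end_range, chunk_size):
    # Yield the start offset of each chunk between the end of the first GET
    # and the end of the requested data.
    offset = start_range
    while offset < end_range:
        yield offset
        offset += chunk_size

# e.g. the first GET covered bytes 0-33554431; fetch the rest in 4 MiB chunks
offsets = list(chunk_offsets(33554432, 50_000_000, 4 * 1024 * 1024))
```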
@@ -7,10 +7,10 @@
import sys
import asyncio
from io import BytesIO
+from itertools import islice

from azure.core.exceptions import HttpResponseError

-from .models import ModifiedAccessConditions
from .request_handlers import validate_and_format_range_headers
from .response_handlers import process_storage_error, parse_length_from_content_range
from .encryption import decrypt_blob
@@ -20,9 +20,7 @@
async def process_content(data, start_offset, end_offset, encryption):
if data is None:
raise ValueError("Response cannot be None.")
-    content = b""
-    async for chunk in data:
-        content += chunk
+    content = data.response.body
if encryption.get('key') is not None or encryption.get('resolver') is not None:
try:
return decrypt_blob(
@@ -41,7 +39,7 @@ async def process_content(data, start_offset, end_offset, encryption):
return content


-class _AsyncChunkDownloader(object):
+class _AsyncChunkDownloader(object):  # pylint: disable=too-many-instance-attributes

def __init__(
self, service=None,
@@ -51,6 +49,7 @@ def __init__(
start_range=None,
end_range=None,
stream=None,
+            parallel=None,
validate_content=None,
encryption_options=None,
**kwargs):
@@ -65,6 +64,12 @@ def __init__(

# the destination that we will write to
self.stream = stream
+        self.stream_lock = asyncio.Lock() if parallel else None
+        self.progress_lock = asyncio.Lock() if parallel else None
+
+        # for a parallel download, the stream is always seekable, so we note down the current position
+        # in order to seek to the right place when out-of-order chunks come in
+        self.stream_start = stream.tell() if parallel else None

# download progress so far
self.progress_total = current_progress
@@ -95,19 +100,25 @@ async def process_chunk(self, chunk_start):
length = chunk_end - chunk_start
if length > 0:
await self._write_to_stream(chunk_data, chunk_start)
-            self._update_progress(length)
+            await self._update_progress(length)

async def yield_chunk(self, chunk_start):
chunk_start, chunk_end = self._calculate_range(chunk_start)
return await self._download_chunk(chunk_start, chunk_end)

    async def _update_progress(self, length):
-        async with self.progress_lock:
-            self.progress_total += length
+        if self.progress_lock:
+            async with self.progress_lock:
+                self.progress_total += length
+        else:
+            self.progress_total += length

    async def _write_to_stream(self, chunk_data, chunk_start):
-        async with self.stream_lock:
-            self.stream.seek(self.stream_start + (chunk_start - self.start_index))
-            self.stream.write(chunk_data)
+        if self.stream_lock:
+            async with self.stream_lock:
+                self.stream.seek(self.stream_start + (chunk_start - self.start_index))
+                self.stream.write(chunk_data)
+        else:
+            self.stream.write(chunk_data)

async def _download_chunk(self, chunk_start, chunk_end):
@@ -138,7 +149,7 @@ async def _download_chunk(self, chunk_start, chunk_end):

return chunk_data

-class StorageStreamDownloader(object):
+class StorageStreamDownloader(object):  # pylint: disable=too-many-instance-attributes
"""A streaming object to download from Azure Storage.

The stream downloader can be iterated, or used to download to an open file or stream
@@ -152,7 +163,6 @@ def __init__(
length=None,
validate_content=None,
encryption_options=None,
-            extra_properties=None,
**kwargs):
self.service = service
self.config = config
@@ -163,6 +173,9 @@ def __init__(
self.request_options = kwargs
self.location_mode = None
self._download_complete = False
+        self._current_content = None
+        self._iter_downloader = None
+        self._iter_chunks = None

# The service only provides transactional MD5s for chunks under 4MB.
# If validate_content is on, get only self.MAX_CHUNK_GET_SIZE for the first
@@ -177,10 +190,62 @@ def __init__(

self.initial_range, self.initial_offset = process_range_and_offset(
initial_request_start, initial_request_end, self.length, self.encryption_options)

self.download_size = None
self.file_size = None
-        self.response = self._initial_request()
-        self.properties = self.response.properties
+        self.response = None
+        self.properties = None

+    def __len__(self):
+        return self.download_size
+
+    def __iter__(self):
+        raise TypeError("Async stream must be iterated asynchronously.")
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        """Iterate through responses."""
+        if self._current_content is None:
+            if self.download_size == 0:
+                self._current_content = b""
+            else:
+                self._current_content = await process_content(
+                    self.response, self.initial_offset[0], self.initial_offset[1], self.encryption_options)
+            if not self._download_complete:
+                data_end = self.file_size
+                if self.length is not None:
+                    # Use the length unless it is over the end of the file
+                    data_end = min(self.file_size, self.length + 1)
+                self._iter_downloader = _AsyncChunkDownloader(
+                    service=self.service,
+                    total_size=self.download_size,
+                    chunk_size=self.config.max_chunk_get_size,
+                    current_progress=self.first_get_size,
+                    start_range=self.initial_range[1] + 1,  # start where the first download ended
+                    end_range=data_end,
+                    stream=None,
+                    parallel=False,
+                    validate_content=self.validate_content,
+                    encryption_options=self.encryption_options,
+                    use_location=self.location_mode,
+                    **self.request_options)
+                self._iter_chunks = self._iter_downloader.get_chunk_offsets()
+        elif self._download_complete:
+            raise StopAsyncIteration("Download complete")
+        else:
+            try:
+                chunk = next(self._iter_chunks)
+            except StopIteration:
+                raise StopAsyncIteration("Download complete")
+            self._current_content = await self._iter_downloader.yield_chunk(chunk)
+
+        return self._current_content
+
+    async def setup(self, extra_properties=None):
+        if self.response:
+            raise ValueError("Download stream already initialized.")
+        self.response = await self._initial_request()
+        self.properties = self.response.properties
+
# Set the content length to the download size instead of the size of
@@ -200,49 +265,7 @@ def __init__(
# TODO: Set to the stored MD5 when the service returns this
self.properties.content_md5 = None

-    def __len__(self):
-        return self.download_size
-
-    def __iter__(self):
-        raise TypeError("Async stream must be iterated asynchronously.")
-
-    def __aiter__(self):
-        return self._async_data_iterator()
-
-    async def _async_data_iterator(self):
-        if self.download_size == 0:
-            content = b""
-        else:
-            content = process_content(
-                self.response, self.initial_offset[0], self.initial_offset[1], self.encryption_options)
-
-        if content is not None:
-            yield content
-        if self._download_complete:
-            return
-
-        data_end = self.file_size
-        if self.length is not None:
-            # Use the length unless it is over the end of the file
-            data_end = min(self.file_size, self.length + 1)
-
-        downloader = _AsyncChunkDownloader(
-            service=self.service,
-            total_size=self.download_size,
-            chunk_size=self.config.max_chunk_get_size,
-            current_progress=self.first_get_size,
-            start_range=self.initial_range[1] + 1,  # start where the first download ended
-            end_range=data_end,
-            stream=stream,
-            validate_content=self.validate_content,
-            encryption_options=self.encryption_options,
-            use_location=self.location_mode,
-            **self.request_options)
-
-        for chunk in downloader.get_chunk_offsets():
-            yield await downloader.yield_chunk(chunk)
-
-    def _initial_request(self):
+    async def _initial_request(self):
range_header, range_validation = validate_and_format_range_headers(
self.initial_range[0],
self.initial_range[1],
@@ -251,7 +274,7 @@ def _initial_request(self):
check_content_md5=self.validate_content)

try:
-            location_mode, response = self.service.download(
+            location_mode, response = await self.service.download(
range=range_header,
range_get_content_md5=range_validation,
validate_content=self.validate_content,
@@ -280,7 +303,7 @@ def _initial_request(self):
# request a range, do a regular get request in order to get
# any properties.
try:
-                _, response = self.service.download(
+                _, response = await self.service.download(
validate_content=self.validate_content,
data_stream_total=0,
download_stream_current=0,
@@ -303,7 +326,6 @@ def _initial_request(self):
self.request_options['modified_access_conditions'].if_match = response.properties.etag
else:
self._download_complete = True

return response

async def content_as_bytes(self, max_connections=1):
@@ -341,8 +363,12 @@ async def download_to_stream(self, stream, max_connections=1):
:returns: The properties of the downloaded file.
:rtype: Any
"""
+        if self._iter_downloader:
+            raise ValueError("Stream is currently being iterated.")
+
# the stream must be seekable if parallel download is required
-        if max_connections > 1:
+        parallel = max_connections > 1
+        if parallel:
error_message = "Target stream handle must be seekable."
if sys.version_info >= (3,) and not stream.seekable():
raise ValueError(error_message)
@@ -355,7 +381,7 @@
if self.download_size == 0:
content = b""
else:
-            content = process_content(
+            content = await process_content(
self.response, self.initial_offset[0], self.initial_offset[1], self.encryption_options)

# Write the content to the user stream
@@ -377,6 +403,7 @@
start_range=self.initial_range[1] + 1, # start where the first download ended
end_range=data_end,
stream=stream,
+            parallel=parallel,
validate_content=self.validate_content,
encryption_options=self.encryption_options,
use_location=self.location_mode,
@@ -387,7 +414,7 @@
asyncio.ensure_future(downloader.process_chunk(d))
for d in islice(dl_tasks, 0, max_connections)
]
-        while True:
+        while running_futures:
# Wait for some download to finish before adding a new one
_done, running_futures = await asyncio.wait(
running_futures, return_when=asyncio.FIRST_COMPLETED)
@@ -398,6 +425,7 @@
else:
running_futures.add(asyncio.ensure_future(downloader.process_chunk(next_chunk)))

-        # Wait for the remaining downloads to finish
-        await asyncio.wait(running_futures)
+        if running_futures:
+            # Wait for the remaining downloads to finish
+            await asyncio.wait(running_futures)
return self.properties
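Two patterns in this file are worth calling out. First, setup is now two-phase: the constructor stays synchronous and `setup()` performs the initial range GET, after which the object can be async-iterated chunk by chunk. Second, `download_to_stream` keeps at most `max_connections` chunk downloads in flight by seeding a window of tasks with `islice` and scheduling another chunk each time one finishes. A hedged sketch of both patterns; names such as `downloader` and `make_chunk_task` are illustrative stand-ins:

```python
import asyncio
from itertools import islice

async def consume(downloader):
    # Two-phase start: construct synchronously, then await the initial request.
    await downloader.setup()
    async for chunk in downloader:  # sequential chunk-by-chunk iteration
        print(len(chunk))

async def run_windowed(make_chunk_task, chunk_offsets, window):
    # Sliding window over chunk downloads, mirroring the loop in this diff:
    # wait for any task to finish, then top the window back up.
    chunk_offsets = iter(chunk_offsets)
    running = {asyncio.ensure_future(make_chunk_task(c))
               for c in islice(chunk_offsets, window)}
    while running:
        done, running = await asyncio.wait(
            running, return_when=asyncio.FIRST_COMPLETED)
        for _ in done:
            try:
                running.add(asyncio.ensure_future(make_chunk_task(next(chunk_offsets))))
            except StopIteration:
                break
```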
@@ -366,7 +366,7 @@ def generate_blob_encryption_data(key_encryption_key):


def decrypt_blob(require_encryption, key_encryption_key, key_resolver,
-        content, start_offset, end_offset, response_headers):
+                 content, start_offset, end_offset, response_headers):
'''
Decrypts the given blob contents and returns only the requested range.
