Improve typings in multipart
kornicameister committed Feb 23, 2019
commit aa6ceea2c9034b7ec3bedffa256fd9ab8a7d9b47
4 changes: 4 additions & 0 deletions CHANGES/3621.rst
@@ -0,0 +1,4 @@
+Improve typing annotations for multipart.
+
+Use `async for` instead of a `while` loop for
+reading full multipart data in `web_request.Request#post`.
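
Note: the rewrite described above amounts to the following, shown here as a schematic sketch with simplified names, not the exact aiohttp code:

    async def old_way(multipart):
        field = await multipart.next()
        while field is not None:
            await field.read()
            field = await multipart.next()

    async def new_way(multipart):
        # `async for` relies on MultipartReader.__aiter__/__anext__,
        # which raise StopAsyncIteration once next() returns None.
        async for field in multipart:
            await field.read()
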
21 changes: 12 additions & 9 deletions aiohttp/multipart.py
@@ -240,7 +240,7 @@ class BodyPartReader:
     def __init__(
         self,
         boundary: bytes,
-        headers: Mapping[str, Optional[str]],
+        headers: Mapping[str, str],
         content: StreamReader,
         *,
         _newline: bytes = b'\r\n'
@@ -443,7 +443,7 @@ def decode(self, data: bytes) -> bytes:
         return data
 
     def _decode_content(self, data: bytes) -> bytes:
-        encoding = cast(str, self.headers[CONTENT_ENCODING]).lower()
+        encoding = self.headers[CONTENT_ENCODING].lower()
 
         if encoding == 'deflate':
             return zlib.decompress(data, -zlib.MAX_WBITS)
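
Aside for reviewers: the negative wbits argument is what makes this work for the HTTP `deflate` coding: it tells zlib to expect a raw deflate stream without the zlib header and checksum. A quick self-contained check (an assumed illustration, not part of the diff):

    import zlib

    # Produce a raw deflate stream (no zlib header/trailer), as many
    # HTTP peers send for the `deflate` content coding.
    compressor = zlib.compressobj(wbits=-zlib.MAX_WBITS)
    payload = compressor.compress(b'hello world') + compressor.flush()

    # Negative wbits on decompression likewise disables header parsing.
    assert zlib.decompress(payload, -zlib.MAX_WBITS) == b'hello world'
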
@@ -455,7 +455,7 @@ def _decode_content(self, data: bytes) -> bytes:
         raise RuntimeError('unknown content encoding: {}'.format(encoding))
 
     def _decode_content_transfer(self, data: bytes) -> bytes:
-        encoding = cast(str, self.headers[CONTENT_TRANSFER_ENCODING]).lower()
+        encoding = self.headers[CONTENT_TRANSFER_ENCODING].lower()
 
         if encoding == 'base64':
             return base64.b64decode(data)
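
Both removed `cast(str, ...)` calls were only compensating for the old `Mapping[str, Optional[str]]` value type; once the header mapping is `Mapping[str, str]`, indexing already yields `str`. In isolation (an illustrative sketch, not aiohttp code):

    from typing import Mapping, Optional, cast

    def old_style(headers: Mapping[str, Optional[str]]) -> str:
        # Optional[str] values forced a cast before calling .lower().
        return cast(str, headers['Content-Encoding']).lower()

    def new_style(headers: Mapping[str, str]) -> str:
        # With str values, mypy accepts the attribute access directly.
        return headers['Content-Encoding'].lower()
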
@@ -539,15 +539,15 @@ def __init__(
         self._boundary = ('--' + self._get_boundary()).encode()
         self._newline = _newline
         self._content = content
-        self._last_part = None
+        self._last_part = None  # type: Optional[Union[MultipartReader, BodyPartReader]]
         self._at_eof = False
         self._at_bof = True
         self._unread = []  # type: List[bytes]
 
     def __aiter__(self) -> 'MultipartReader':
         return self
 
-    async def __anext__(self) -> Any:
+    async def __anext__(self) -> Union['MultipartReader', BodyPartReader]:
         part = await self.next()
         if part is None:
             raise StopAsyncIteration  # NOQA
@@ -569,19 +569,19 @@ def at_eof(self) -> bool:
"""
return self._at_eof

async def next(self) -> Any:
async def next(self) -> Optional[Union['MultipartReader', BodyPartReader]]:
"""Emits the next multipart body part."""
# So, if we're at BOF, we need to skip till the boundary.
if self._at_eof:
return
return None
await self._maybe_release_last_part()
if self._at_bof:
await self._read_until_first_boundary()
self._at_bof = False
else:
await self._read_boundary()
if self._at_eof: # we just read the last boundary, nothing to do there
return
return None
self._last_part = await self.fetch_next_part()
return self._last_part

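For context, the `Optional[...]` return type documents the existing contract between `next()` and `__anext__`: `None` means the stream is exhausted, and `__anext__` converts it into `StopAsyncIteration` so `async for` terminates cleanly. The same pattern in miniature (a standalone sketch, not aiohttp's classes):

    from typing import Optional

    class Parts:
        def __init__(self) -> None:
            self._items = ['part-1', 'part-2']

        async def next(self) -> Optional[str]:
            # None signals exhaustion, mirroring MultipartReader.next().
            return self._items.pop(0) if self._items else None

        def __aiter__(self) -> 'Parts':
            return self

        async def __anext__(self) -> str:
            part = await self.next()
            if part is None:
                raise StopAsyncIteration
            return part
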
@@ -598,7 +598,10 @@ async def fetch_next_part(self) -> Any:
         headers = await self._read_headers()
         return self._get_part_reader(headers)
 
-    def _get_part_reader(self, headers: 'CIMultiDictProxy[str]') -> Any:
+    def _get_part_reader(
+        self,
+        headers: 'CIMultiDictProxy[str]',
+    ) -> Union['MultipartReader', BodyPartReader]:
         """Dispatches the response by the `Content-Type` header, returning
         suitable reader instance.

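With `_get_part_reader()` returning `Union['MultipartReader', BodyPartReader]`, callers now have to narrow the type before using part-specific APIs; the `web_request.py` hunk below does exactly that with `isinstance()`. A hypothetical consumer (`collect_text` is illustrative, not aiohttp API):

    from typing import List
    from aiohttp import BodyPartReader, MultipartReader

    async def collect_text(reader: MultipartReader) -> List[str]:
        texts = []  # type: List[str]
        async for part in reader:
            if isinstance(part, BodyPartReader):
                texts.append(await part.text())
            else:
                # Nested multipart: recurse into the sub-reader.
                texts.extend(await collect_text(part))
        return texts
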
73 changes: 38 additions & 35 deletions aiohttp/web_request.py
@@ -32,7 +32,7 @@
 from .abc import AbstractStreamWriter
 from .helpers import DEBUG, ChainMapProxy, HeadersMixin, reify, sentinel
 from .http_parser import RawRequestMessage
-from .multipart import MultipartReader
+from .multipart import MultipartReader, BodyPartReader
 from .streams import EmptyStreamReader, StreamReader
 from .typedefs import (
     DEFAULT_JSON_DECODER,
@@ -608,46 +608,49 @@ async def post(self) -> 'MultiDictProxy[Union[str, bytes, FileField]]':
             multipart = await self.multipart()
             max_size = self._client_max_size
 
-            field = await multipart.next()
-            while field is not None:
+            async for field in multipart:
                 size = 0
-                content_type = field.headers.get(hdrs.CONTENT_TYPE)
-
-                if field.filename:
-                    # store file in temp file
-                    tmp = tempfile.TemporaryFile()
-                    chunk = await field.read_chunk(size=2**16)
-                    while chunk:
-                        chunk = field.decode(chunk)
-                        tmp.write(chunk)
-                        size += len(chunk)
+                field_content_type = field.headers.get(hdrs.CONTENT_TYPE)
+
+                if isinstance(field, BodyPartReader):
+                    if field.filename:
+                        assert field_content_type is not None, 'Cannot read file without knowing what it is'
+                        # store file in temp file
+                        tmp = tempfile.TemporaryFile()
+                        chunk = await field.read_chunk(size=2**16)
+                        while chunk:
+                            chunk = field.decode(chunk)
+                            tmp.write(chunk)
+                            size += len(chunk)
                             if 0 < max_size < size:
                                 raise HTTPRequestEntityTooLarge(
                                     max_size=max_size,
                                     actual_size=size
                                 )
+                            chunk = await field.read_chunk(size=2**16)
+                        tmp.seek(0)
+
+                        ff = FileField(
+                            field.name,
+                            field.filename,
+                            cast(io.BufferedReader, tmp),
+                            field_content_type,
+                            CIMultiDictProxy(CIMultiDict(**field.headers)),
+                        )
+                        out.add(field.name, ff)
+                    else:
+                        value = await field.read(decode=True)
+                        if content_type is None or \
+                                content_type.startswith('text/'):
+                            charset = field.get_charset(default='utf-8')
+                            value = value.decode(charset)
+                        out.add(field.name, value)
+                        size += len(value)
+                        if 0 < max_size < size:
+                            raise HTTPRequestEntityTooLarge(
+                                max_size=max_size,
+                                actual_size=size
+                            )
-                        chunk = await field.read_chunk(size=2**16)
-                    tmp.seek(0)
-
-                    ff = FileField(field.name, field.filename,
-                                   cast(io.BufferedReader, tmp),
-                                   content_type, field.headers)
-                    out.add(field.name, ff)
-                else:
-                    value = await field.read(decode=True)
-                    if content_type is None or \
-                            content_type.startswith('text/'):
-                        charset = field.get_charset(default='utf-8')
-                        value = value.decode(charset)
-                    out.add(field.name, value)
-                    size += len(value)
-                    if 0 < max_size < size:
-                        raise HTTPRequestEntityTooLarge(
-                            max_size=max_size,
-                            actual_size=size
-                        )
-
-                field = await multipart.next()
         else:
             data = await self.read()
             if data:
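
The net effect in `Request.post()` is that the multipart reader is consumed with `async for` and each yielded field is narrowed with `isinstance()` before file handling. A minimal handler exercising this path (the form field name 'file' is assumed for illustration):

    from aiohttp import web

    async def upload(request: web.Request) -> web.Response:
        data = await request.post()  # runs the async-for loop shown above
        ff = data['file']
        # post() returns Union[str, bytes, FileField] values; narrow for mypy.
        assert isinstance(ff, web.FileField)
        body = ff.file.read()
        return web.Response(text='received {} bytes'.format(len(body)))

    app = web.Application()
    app.router.add_post('/upload', upload)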