diff --git a/CHANGES/3621.bugfix b/CHANGES/3621.bugfix
new file mode 100644
index 00000000000..925916f38a5
--- /dev/null
+++ b/CHANGES/3621.bugfix
@@ -0,0 +1,4 @@
+Improve typing annotations for multipart.
+
+Use `async for` instead of a `while` loop when reading
+the full multipart payload in `web_request.Request.post()`.
diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py
index 02e9f0065a8..b6d28a1d429 100644
--- a/aiohttp/multipart.py
+++ b/aiohttp/multipart.py
@@ -240,7 +240,7 @@ class BodyPartReader:
     def __init__(
         self,
         boundary: bytes,
-        headers: Mapping[str, Optional[str]],
+        headers: Mapping[str, str],
         content: StreamReader,
         *,
         _newline: bytes = b'\r\n'
@@ -443,7 +443,7 @@ def decode(self, data: bytes) -> bytes:
         return data
 
     def _decode_content(self, data: bytes) -> bytes:
-        encoding = cast(str, self.headers[CONTENT_ENCODING]).lower()
+        encoding = self.headers[CONTENT_ENCODING].lower()
 
         if encoding == 'deflate':
             return zlib.decompress(data, -zlib.MAX_WBITS)
@@ -455,7 +455,7 @@ def _decode_content(self, data: bytes) -> bytes:
         raise RuntimeError('unknown content encoding: {}'.format(encoding))
 
     def _decode_content_transfer(self, data: bytes) -> bytes:
-        encoding = cast(str, self.headers[CONTENT_TRANSFER_ENCODING]).lower()
+        encoding = self.headers[CONTENT_TRANSFER_ENCODING].lower()
 
         if encoding == 'base64':
             return base64.b64decode(data)
@@ -547,7 +547,7 @@ def __init__(
     def __aiter__(self) -> 'MultipartReader':
         return self
 
-    async def __anext__(self) -> Any:
+    async def __anext__(self) -> Union['MultipartReader', BodyPartReader]:
         part = await self.next()
         if part is None:
             raise StopAsyncIteration  # NOQA
@@ -569,11 +569,11 @@ def at_eof(self) -> bool:
         """
         return self._at_eof
 
-    async def next(self) -> Any:
+    async def next(self) -> Optional[Union['MultipartReader', BodyPartReader]]:
         """Emits the next multipart body part."""
         # So, if we're at BOF, we need to skip till the boundary.
         if self._at_eof:
-            return
+            return None
         await self._maybe_release_last_part()
         if self._at_bof:
             await self._read_until_first_boundary()
@@ -581,7 +581,7 @@ async def next(self) -> Any:
         else:
             await self._read_boundary()
         if self._at_eof:  # we just read the last boundary, nothing to do there
-            return
+            return None
         self._last_part = await self.fetch_next_part()
         return self._last_part
 
@@ -598,7 +598,10 @@ async def fetch_next_part(self) -> Any:
         headers = await self._read_headers()
         return self._get_part_reader(headers)
 
-    def _get_part_reader(self, headers: 'CIMultiDictProxy[str]') -> Any:
+    def _get_part_reader(
+        self,
+        headers: 'CIMultiDictProxy[str]',
+    ) -> Union['MultipartReader', BodyPartReader]:
         """Dispatches the response by the `Content-Type` header, returning
         suitable reader instance.
 
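Editor's note: a minimal sketch (not part of the diff) of how the tightened annotations above are meant to be consumed. The handler name and the '<unnamed>' placeholder are hypothetical; the aiohttp calls (`request.multipart()`, `part.release()`) are real public API.

from aiohttp import BodyPartReader, web


async def upload(request: web.Request) -> web.Response:
    # request.multipart() returns a MultipartReader; with the new
    # annotations its next()/__anext__ yield
    # Union[MultipartReader, BodyPartReader], so a type checker now
    # forces callers to narrow the part before using
    # BodyPartReader-only attributes such as .name.
    reader = await request.multipart()
    names = []
    async for part in reader:
        if isinstance(part, BodyPartReader):
            names.append(part.name or '<unnamed>')
            await part.release()  # drain this part's payload
    return web.Response(text='received: ' + ', '.join(names))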
diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py
index 0534f7e265f..a3ea916ebe1 100644
--- a/aiohttp/web_request.py
+++ b/aiohttp/web_request.py
@@ -32,7 +32,7 @@
 from .abc import AbstractStreamWriter
 from .helpers import DEBUG, ChainMapProxy, HeadersMixin, reify, sentinel
 from .http_parser import RawRequestMessage
-from .multipart import MultipartReader
+from .multipart import BodyPartReader, MultipartReader
 from .streams import EmptyStreamReader, StreamReader
 from .typedefs import (
     DEFAULT_JSON_DECODER,
@@ -608,46 +608,50 @@ async def post(self) -> 'MultiDictProxy[Union[str, bytes, FileField]]':
             multipart = await self.multipart()
             max_size = self._client_max_size
 
-            field = await multipart.next()
-            while field is not None:
+            async for field in multipart:
                 size = 0
-                content_type = field.headers.get(hdrs.CONTENT_TYPE)
-
-                if field.filename:
-                    # store file in temp file
-                    tmp = tempfile.TemporaryFile()
-                    chunk = await field.read_chunk(size=2**16)
-                    while chunk:
-                        chunk = field.decode(chunk)
-                        tmp.write(chunk)
-                        size += len(chunk)
+                field_content_type = field.headers.get(hdrs.CONTENT_TYPE)
+
+                if isinstance(field, BodyPartReader):
+                    if field.filename:
+                        assert field_content_type is not None, \
+                            'Cannot read file without knowing what it is'
+                        # store file in temp file
+                        tmp = tempfile.TemporaryFile()
+                        chunk = await field.read_chunk(size=2**16)
+                        while chunk:
+                            chunk = field.decode(chunk)
+                            tmp.write(chunk)
+                            size += len(chunk)
+                            if 0 < max_size < size:
+                                raise HTTPRequestEntityTooLarge(
+                                    max_size=max_size,
+                                    actual_size=size
+                                )
+                            chunk = await field.read_chunk(size=2**16)
+                        tmp.seek(0)
+
+                        ff = FileField(
+                            field.name,
+                            field.filename,
+                            cast(io.BufferedReader, tmp),
+                            field_content_type,
+                            CIMultiDictProxy(CIMultiDict(**field.headers)),
+                        )
+                        out.add(field.name, ff)
+                    else:
+                        value = await field.read(decode=True)
+                        if field_content_type is None or \
+                                field_content_type.startswith('text/'):
+                            charset = field.get_charset(default='utf-8')
+                            value = value.decode(charset)
+                        out.add(field.name, value)
+                        size += len(value)
                         if 0 < max_size < size:
                             raise HTTPRequestEntityTooLarge(
                                 max_size=max_size,
                                 actual_size=size
                             )
-                        chunk = await field.read_chunk(size=2**16)
-                    tmp.seek(0)
-
-                    ff = FileField(field.name, field.filename,
-                                   cast(io.BufferedReader, tmp),
-                                   content_type, field.headers)
-                    out.add(field.name, ff)
-                else:
-                    value = await field.read(decode=True)
-                    if content_type is None or \
-                            content_type.startswith('text/'):
-                        charset = field.get_charset(default='utf-8')
-                        value = value.decode(charset)
-                    out.add(field.name, value)
-                    size += len(value)
-                    if 0 < max_size < size:
-                        raise HTTPRequestEntityTooLarge(
-                            max_size=max_size,
-                            actual_size=size
-                        )
-
-                field = await multipart.next()
         else:
             data = await self.read()
             if data:
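Editor's note: the `async for` rewrite in `Request.post()` is behavior-preserving because `MultipartReader.__anext__` (unchanged above) calls `next()` and converts its `None` sentinel into `StopAsyncIteration`. A self-contained toy illustrating that contract; `ToyReader` is hypothetical, not aiohttp code:

import asyncio
from typing import List, Optional


class ToyReader:
    """Mimics MultipartReader's iteration contract: next() yields
    items until it returns None, and __anext__ translates that
    sentinel into StopAsyncIteration for async-for consumers."""

    def __init__(self, items: List[str]) -> None:
        self._items = iter(items)

    def __aiter__(self) -> 'ToyReader':
        return self

    async def __anext__(self) -> str:
        item = await self.next()
        if item is None:
            raise StopAsyncIteration
        return item

    async def next(self) -> Optional[str]:
        return next(self._items, None)


async def main() -> None:
    # Prints a, b, c: the async-for form is the old
    # `while field is not None` loop with the sentinel check
    # folded into __anext__.
    reader = ToyReader(['a', 'b', 'c'])
    async for item in reader:
        print(item)


asyncio.run(main())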