diff --git a/aiohttp/web.py b/aiohttp/web.py index 7022d109c91..fc3828f444d 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -1,11 +1,13 @@ import asyncio import binascii import collections +import mimetypes import cgi import http.cookies import io import json import re +import os from urllib.parse import urlsplit, parse_qsl, unquote @@ -893,6 +895,53 @@ def add_route(self, method, path, handler): compiled = re.compile('^' + pattern + '$') self._urls.append(Entry(compiled, method, handler)) + def _static_file_handler_maker(self, path): + @asyncio.coroutine + def _handler(request): + resp = StreamResponse(request) + filename = request.match_info['filename'] + filepath = os.path.join(path, filename) + if '..' in filename: + raise HTTPNotFound(request) + if not os.path.exists(filepath) or os.path.isdir(filepath): + raise HTTPNotFound(request) + + ct = mimetypes.guess_type(filename)[0] + if not ct: + ct = 'application/octet-stream' + resp.content_type = ct + + resp.headers['transfer-encoding'] = 'chunked' + resp.send_headers() + + with open(filepath, 'rb') as f: + chunk = f.read(1024) + while chunk: + resp.write(chunk) + chunk = f.read(1024) + + yield from resp.write_eof() + return resp + + return _handler + + def add_static(self, prefix, path): + """ + Adds static files view + :param prefix - url prefix + :param path - folder with files + """ + assert prefix.startswith('/') + assert os.path.exists(path), 'Path does not exist' + method = 'GET' + suffix = r'(?P.*)' # match everything after static prefix + if not prefix.endswith('/'): + prefix += '/' + compiled = re.compile('^' + prefix + suffix + '$') + self._urls.append(Entry( + compiled, method, self._static_file_handler_maker(path) + )) + ############################################################ # Application implementation diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 7a0c4c1fe00..ea3138a0fd4 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -3,6 +3,7 @@ import os.path import socket import unittest +import tempfile from aiohttp import web, request, FormData @@ -23,9 +24,10 @@ def find_unused_port(self): return port @asyncio.coroutine - def create_server(self, method, path, handler): + def create_server(self, method, path, handler=None): app = web.Application(loop=self.loop, debug=True) - app.router.add_route(method, path, handler) + if handler: + app.router.add_route(method, path, handler) port = self.find_unused_port() srv = yield from self.loop.create_server(app.make_handler, @@ -236,3 +238,32 @@ def go(): self.assertEqual(200, resp.status) self.loop.run_until_complete(go()) + + def test_static_file(self): + + @asyncio.coroutine + def go(tmpdirname, filename): + app, _, url = yield from self.create_server( + 'GET', '/static/' + filename + ) + app.router.add_static('/static', tmpdirname) + + resp = yield from request('GET', url, loop=self.loop) + self.assertEqual(200, resp.status) + txt = yield from resp.text() + self.assertEqual('file content', txt) + ct = resp.headers['CONTENT-TYPE'] + self.assertEqual('application/octet-stream', ct) + + resp = yield from request('GET', url+'fake', loop=self.loop) + self.assertEqual(404, resp.status) + resp = yield from request('GET', url+'/../../', loop=self.loop) + self.assertEqual(404, resp.status) + + with tempfile.TemporaryDirectory() as tmpdirname: + with tempfile.NamedTemporaryFile(dir=tmpdirname) as fp: + filename = os.path.basename(fp.name) + fp.write(b'file content') + fp.flush() + fp.seek(0) + self.loop.run_until_complete(go(tmpdirname, filename))