From 8e079da550c2bfcbe98a44ca47b79097945a5111 Mon Sep 17 00:00:00 2001
From: Andrey Falaleev
Date: Thu, 27 Oct 2022 16:46:29 +0400
Subject: [PATCH 1/3] NDEV-813 add readonly mode for proxy

---
 proxy/common_neon/config.py                     |  6 ++++++
 proxy/neon_rpc_api_model/neon_rpc_api_worker.py | 14 ++++++++------
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/proxy/common_neon/config.py b/proxy/common_neon/config.py
index 47d4181e2..ceba601cf 100644
--- a/proxy/common_neon/config.py
+++ b/proxy/common_neon/config.py
@@ -25,6 +25,7 @@ def __init__(self):
         self._recheck_resource_after_uses_cnt = self._env_int("RECHECK_RESOURCE_AFTER_USES_CNT", 10, 60)
         self._retry_on_fail = self._env_int("RETRY_ON_FAIL", 1, 10)
         self._enable_private_api = self._env_bool("ENABLE_PRIVATE_API", False)
+        self._enable_send_tx_api = self._env_bool("ENABLE_SEND_TX_API", True)
         self._use_earliest_block_if_0_passed = self._env_bool("USE_EARLIEST_BLOCK_IF_0_PASSED", False)
         self._account_permission_update_int = self._env_int("ACCOUNT_PERMISSION_UPDATE_INT", 10, 60 * 5)
         self._allow_underpriced_tx_wo_chainid = self._env_bool("ALLOW_UNDERPRICED_TX_WITHOUT_CHAINID", False)
@@ -148,6 +149,10 @@ def retry_on_fail(self) -> int:
     def enable_private_api(self) -> bool:
         return self._enable_private_api
 
+    @property
+    def enable_send_tx_api(self) -> bool:
+        return self._enable_send_tx_api
+
     @property
     def use_earliest_block_if_0_passed(self) -> bool:
         return self._use_earliest_block_if_0_passed
@@ -298,6 +303,7 @@ def __str__(self):
             f"RECHECK_RESOURCE_AFTER_USES_CNT: {self.recheck_resource_after_uses_cnt}",
             f"RETRY_ON_FAIL: {self.retry_on_fail}",
             f"ENABLE_PRIVATE_API: {self.enable_private_api}",
+            f"ENABLE_SEND_TX_API: {self.enable_send_tx_api}",
             f"USE_EARLIEST_BLOCK_IF_0_PASSED: {self.use_earliest_block_if_0_passed}",
             f"ACCOUNT_PERMISSION_UPDATE_INT: {self.account_permission_update_int}",
             f"ALLOW_UNDERPRICED_TX_WITHOUT_CHAINID: {self.allow_underpriced_tx_wo_chainid}",
diff --git a/proxy/neon_rpc_api_model/neon_rpc_api_worker.py b/proxy/neon_rpc_api_model/neon_rpc_api_worker.py
index cdfa303da..117fabe3a 100644
--- a/proxy/neon_rpc_api_model/neon_rpc_api_worker.py
+++ b/proxy/neon_rpc_api_model/neon_rpc_api_worker.py
@@ -785,16 +785,18 @@ def is_allowed_api(self, method_name: str) -> bool:
             f'Neon EVM {self.web3_clientVersion()}'
         )
 
-        if self._config.enable_private_api:
-            return True
+        if method_name == 'eth_sendRawTransaction':
+            return self._config.enable_send_tx_api
 
-        private_method_list = (
+        private_method_set = {
             "eth_accounts",
             "eth_sign",
             "eth_sendTransaction",
             "eth_signTransaction",
-        )
+        }
+
+        if method_name in private_method_set:
+            if (not self._config.enable_send_tx_api) or (not self._config.enable_private_api):
+                return False
 
-        if method_name in private_method_list:
-            return False
         return True

From 2cbbe8262cd5ea8142c552f20db9909dad7d9f14 Mon Sep 17 00:00:00 2001
From: Andrey Falaleev
Date: Thu, 27 Oct 2022 17:53:53 +0400
Subject: [PATCH 2/3] Sometimes Solana returns null in logMessages

---
 proxy/common_neon/solana_neon_tx_receipt.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/proxy/common_neon/solana_neon_tx_receipt.py b/proxy/common_neon/solana_neon_tx_receipt.py
index 2c044c0ea..1b9172605 100644
--- a/proxy/common_neon/solana_neon_tx_receipt.py
+++ b/proxy/common_neon/solana_neon_tx_receipt.py
@@ -462,7 +462,7 @@ def from_tx_meta(tx_meta: SolTxMetaInfo) -> SolTxReceiptInfo:
         _ix_list = msg['instructions']
 
         meta = tx_meta.tx['meta']
-        log_msg_list = meta['logMessages']
+        log_msg_list = meta.get('logMessages', [])
         _inner_ix_list = meta['innerInstructions']
 
         _account_key_list: List[str] = msg['accountKeys']

From 68bb525a8a19d03b89e97dc27ccfb509e3f22214 Mon Sep 17 00:00:00 2001
From: Andrey Falaleev
Date: Fri, 28 Oct 2022 01:33:52 +0400
Subject: [PATCH 3/3] NDEV-814 update core libraries to upstream

---
 Makefile | 2 +-
 README.md | 7 +-
 proxy/__init__.py | 11 +-
 proxy/common/_version.py | 40 +
 proxy/common/backports.py | 117 ++
 proxy/common/constants.py | 99 +-
 proxy/common/flag.py | 395 +++++++
 proxy/common/flags.py | 541 ---------
 proxy/common/logger.py | 50 +
 proxy/common/pki.py | 157 ++-
 proxy/common/plugins.py | 105 ++
 proxy/common/types.py | 32 +-
 proxy/common/utils.py | 224 +++-
 proxy/common/version.py | 6 +-
 proxy/common_neon/solana_tx_error_parser.py | 7 +-
 proxy/core/__init__.py | 4 +
 proxy/core/acceptor/__init__.py | 7 +-
 proxy/core/acceptor/acceptor.py | 309 +++--
 proxy/core/acceptor/pool.py | 207 ++--
 proxy/core/base/__init__.py | 20 +
 proxy/core/base/tcp_server.py | 242 ++++
 proxy/core/base/tcp_tunnel.py | 113 ++
 proxy/core/base/tcp_upstream.py | 107 ++
 proxy/core/connection/__init__.py | 11 +-
 proxy/core/connection/client.py | 40 +-
 proxy/core/connection/connection.py | 62 +-
 proxy/core/connection/pool.py | 185 +++
 proxy/core/connection/server.py | 55 +-
 proxy/core/connection/types.py | 21 +
 proxy/core/event/__init__.py | 5 +-
 proxy/core/event/dispatcher.py | 116 +-
 proxy/core/event/manager.py | 79 ++
 proxy/core/event/names.py | 32 +-
 proxy/core/event/queue.py | 59 +-
 proxy/core/event/subscriber.py | 178 ++-
 proxy/core/listener/__init__.py | 24 +
 proxy/core/listener/base.py | 61 +
 proxy/core/listener/pool.py | 53 +
 proxy/core/listener/tcp.py | 84 ++
 proxy/core/listener/unix.py | 48 +
 proxy/core/ssh/__init__.py | 12 +
 proxy/core/ssh/client.py | 28 -
 proxy/core/ssh/handler.py | 35 +
 proxy/core/ssh/listener.py | 138 +++
 proxy/core/ssh/tunnel.py | 61 -
 proxy/core/threadless.py | 247 ----
 proxy/{plugin/cache => core/tls}/__init__.py | 10 +-
 proxy/core/tls/certificate.py | 54 +
 proxy/core/tls/finished.py | 25 +
 proxy/core/tls/handshake.py | 126 +++
 proxy/core/tls/hello.py | 242 ++++
 proxy/core/tls/key_exchange.py | 40 +
 proxy/core/tls/pretty.py | 16 +
 proxy/core/tls/tls.py | 76 ++
 proxy/core/tls/types.py | 41 +
 proxy/core/work/__init__.py | 32 +
 proxy/core/work/delegate.py | 43 +
 proxy/core/work/fd/__init__.py | 20 +
 proxy/core/work/fd/fd.py | 52 +
 proxy/core/work/fd/local.py | 50 +
 proxy/core/work/fd/remote.py | 53 +
 proxy/core/work/local.py | 42 +
 proxy/core/work/pool.py | 149 +++
 proxy/core/work/remote.py | 39 +
 proxy/core/work/task/__init__.py | 24 +
 proxy/core/work/task/handler.py | 25 +
 proxy/core/work/task/local.py | 50 +
 proxy/core/work/task/remote.py | 48 +
 proxy/core/work/task/task.py | 18 +
 proxy/core/work/threaded.py | 54 +
 proxy/core/work/threadless.py | 425 +++++++
 proxy/core/work/work.py | 111 ++
 proxy/http/__init__.py | 18 +
 proxy/http/codes.py | 64 +-
 proxy/http/connection.py | 20 +
 proxy/http/descriptors.py | 40 +
 proxy/http/exception/__init__.py | 3 +-
 proxy/http/exception/base.py | 24 +-
 proxy/http/exception/http_request_rejected.py | 44 +-
 proxy/http/exception/proxy_auth_failed.py | 34 +-
 proxy/http/exception/proxy_conn_failed.py | 35 +-
 proxy/http/handler.py | 616 +++++-----
 proxy/http/headers.py | 30 +
 proxy/http/methods.py | 101 +-
 proxy/http/parser.py | 265 -----
 proxy/http/parser/__init__.py | 30 +
 proxy/http/{chunk_parser.py => parser/chunk.py} | 28 +-
 proxy/http/parser/parser.py | 460 ++++++++
 proxy/http/parser/protocol.py | 50 +
 proxy/http/parser/types.py | 37 +
 proxy/http/plugin.py | 99 ++
 proxy/http/protocols.py | 34 +
 proxy/http/proxy/__init__.py | 1 +
 proxy/http/proxy/auth.py | 38 +
 proxy/http/proxy/plugin.py | 125 +-
 proxy/http/proxy/server.py | 1003 ++++++++++++-----
 proxy/http/responses.py | 153 +++
 proxy/http/server/__init__.py | 3 +-
 .../__init__.py => http/server/middleware.py} | 6 +
 proxy/http/server/pac_plugin.py | 65 +-
 proxy/http/server/plugin.py | 159 ++-
 proxy/http/server/protocols.py | 19 +-
 proxy/http/server/web.py | 379 ++++---
 proxy/http/url.py | 159 +++
 proxy/http/websocket/__init__.py | 28 +
 proxy/http/websocket/client.py | 125 ++
 .../http/{websocket.py => websocket/frame.py} | 159 +--
 proxy/http/websocket/plugin.py | 65 ++
 proxy/http/websocket/transport.py | 84 ++
 .../neon_tx_send_iterative_strategy.py | 2 +-
 proxy/plugin/__init__.py | 32 +-
 proxy/plugin/cache/base.py | 60 -
 proxy/plugin/cache/cache_responses.py | 29 -
 proxy/plugin/cache/store/base.py | 36 -
 proxy/plugin/cache/store/disk.py | 49 -
 proxy/plugin/filter_by_upstream.py | 43 -
 proxy/plugin/man_in_the_middle.py | 36 -
 proxy/plugin/mock_rest_api.py | 88 --
 proxy/plugin/modify_post_data.py | 47 -
 proxy/plugin/proxy_pool.py | 84 --
 proxy/plugin/redirect_to_custom_server.py | 45 -
 proxy/plugin/reverse_proxy.py | 76 --
 proxy/plugin/shortlink.py | 84 --
 proxy/plugin/web_server_route.py | 48 -
 proxy/proxy.py | 340 +++++-
 tests/plugin/utils.py | 4 +-
 tests/test_main.py | 206 ----
 127 files changed, 8494 insertions(+), 3719 deletions(-)
 create mode 100644 proxy/common/_version.py
 create mode 100644 proxy/common/backports.py
 create mode 100644 proxy/common/flag.py
 delete mode 100644 proxy/common/flags.py
 create mode 100644 proxy/common/logger.py
 create mode 100644 proxy/common/plugins.py
 create mode 100644 proxy/core/base/__init__.py
 create mode 100644 proxy/core/base/tcp_server.py
 create mode 100644 proxy/core/base/tcp_tunnel.py
 create mode 100644 proxy/core/base/tcp_upstream.py
 create mode 100644 proxy/core/connection/pool.py
 create mode 100644 proxy/core/connection/types.py
 create mode 100644 proxy/core/event/manager.py
 create mode 100644 proxy/core/listener/__init__.py
 create mode 100644 proxy/core/listener/base.py
 create mode 100644 proxy/core/listener/pool.py
 create mode 100644 proxy/core/listener/tcp.py
 create mode 100644 proxy/core/listener/unix.py
 delete mode 100644 proxy/core/ssh/client.py
 create mode 100644 proxy/core/ssh/handler.py
 create mode 100644 proxy/core/ssh/listener.py
 delete mode 100644 proxy/core/ssh/tunnel.py
 delete mode 100644 proxy/core/threadless.py
 rename proxy/{plugin/cache => core/tls}/__init__.py (70%)
 create mode 100644 proxy/core/tls/certificate.py
 create mode 100644 proxy/core/tls/finished.py
 create mode 100644 proxy/core/tls/handshake.py
 create mode 100644 proxy/core/tls/hello.py
 create mode 100644 proxy/core/tls/key_exchange.py
 create mode 100644 proxy/core/tls/pretty.py
 create mode 100644 proxy/core/tls/tls.py
 create mode 100644 proxy/core/tls/types.py
 create mode 100644 proxy/core/work/__init__.py
 create mode 100644 proxy/core/work/delegate.py
 create mode 100644 proxy/core/work/fd/__init__.py
 create mode 100644 proxy/core/work/fd/fd.py
 create mode 100644 proxy/core/work/fd/local.py
 create mode 100644 proxy/core/work/fd/remote.py
 create mode 100644 proxy/core/work/local.py
 create mode 100644 proxy/core/work/pool.py
 create mode 100644 proxy/core/work/remote.py
 create mode 100644 proxy/core/work/task/__init__.py
 create mode 100644 proxy/core/work/task/handler.py
 create mode 100644 proxy/core/work/task/local.py
 create mode 100644 proxy/core/work/task/remote.py
 create mode 100644 proxy/core/work/task/task.py
 create mode 100644 proxy/core/work/threaded.py
 create mode 100644 proxy/core/work/threadless.py
 create mode 100644 proxy/core/work/work.py
 create mode 100644 proxy/http/connection.py
 create mode 100644 proxy/http/descriptors.py
 create mode 100644 proxy/http/headers.py
 delete mode 100644 proxy/http/parser.py
 create mode 100644 proxy/http/parser/__init__.py
 rename proxy/http/{chunk_parser.py => parser/chunk.py} (81%)
 create mode 100644 proxy/http/parser/parser.py
 create mode 100644 proxy/http/parser/protocol.py
 create mode 100644 proxy/http/parser/types.py
 create mode 100644 proxy/http/plugin.py
 create mode 100644 proxy/http/protocols.py
 create mode 100644 proxy/http/proxy/auth.py
 create mode 100644 proxy/http/responses.py
 rename proxy/{plugin/cache/store/__init__.py => http/server/middleware.py} (71%)
 create mode 100644 proxy/http/url.py
 create mode 100644 proxy/http/websocket/__init__.py
 create mode 100644 proxy/http/websocket/client.py
 rename proxy/http/{websocket.py => websocket/frame.py} (52%)
 create mode 100644 proxy/http/websocket/plugin.py
 create mode 100644 proxy/http/websocket/transport.py
 delete mode 100644 proxy/plugin/cache/base.py
 delete mode 100644 proxy/plugin/cache/cache_responses.py
 delete mode 100644 proxy/plugin/cache/store/base.py
 delete mode 100644 proxy/plugin/cache/store/disk.py
 delete mode 100644 proxy/plugin/filter_by_upstream.py
 delete mode 100644 proxy/plugin/man_in_the_middle.py
 delete mode 100644 proxy/plugin/mock_rest_api.py
 delete mode 100644 proxy/plugin/modify_post_data.py
 delete mode 100644 proxy/plugin/proxy_pool.py
 delete mode 100644 proxy/plugin/redirect_to_custom_server.py
 delete mode 100644 proxy/plugin/reverse_proxy.py
 delete mode 100644 proxy/plugin/shortlink.py
 delete mode 100644 proxy/plugin/web_server_route.py
 delete mode 100644 tests/test_main.py

diff --git a/Makefile b/Makefile
index 137e3ffd4..50101135c 100644
--- a/Makefile
+++ b/Makefile
@@ -13,7 +13,7 @@ CA_KEY_FILE_PATH := ca-key.pem
 CA_CERT_FILE_PATH := ca-cert.pem
 CA_SIGNING_KEY_FILE_PATH := ca-signing-key.pem
 
-.PHONY: all https-certificates ca-certificates autopep8 devtools
+.PHONY: all https-certificates ca-certificates autopep8
 .PHONY: lib-version lib-clean lib-test lib-package lib-coverage lib-lint
 .PHONY: lib-release-test lib-release lib-profile
 .PHONY: container container-run container-release
diff --git a/README.md b/README.md
index 308a03772..9f7c5b0d5 100644
--- a/README.md
+++ b/README.md
@@ -85,7 +85,6 @@ Table of Contents
     * [Man-In-The-Middle Plugin](#maninthemiddleplugin)
     * [Proxy Pool Plugin](#proxypoolplugin)
 * [HTTP Web Server Plugins](#http-web-server-plugins)
-    * [Reverse Proxy](#reverse-proxy)
     * [Web Server Route](#web-server-route)
 * [Plugin Ordering](#plugin-ordering)
 * [End-to-End Encryption](#end-to-end-encryption)
@@ -1610,9 +1609,8 @@ usage: proxy [-h] [--backlog BACKLOG] [--basic-auth BASIC_AUTH]
              [--ca-signing-key-file CA_SIGNING_KEY_FILE]
              [--cert-file CERT_FILE]
              [--client-recvbuf-size CLIENT_RECVBUF_SIZE]
-             [--devtools-ws-path DEVTOOLS_WS_PATH]
              [--disable-headers DISABLE_HEADERS] [--disable-http-proxy]
-             [--enable-devtools] [--enable-events]
+             [--enable-events]
              [--enable-static-server] [--enable-web-server]
              [--hostname HOSTNAME] [--key-file KEY_FILE]
              [--log-level LOG_LEVEL] [--log-file LOG_FILE]
@@ -1660,9 +1658,6 @@ optional arguments:
                         the client in a single recv() operation.
Bump this value for faster uploads at the expense of increased RAM. - --devtools-ws-path DEVTOOLS_WS_PATH - Default: /devtools. Only applicable if --enable- - devtools is used. --disable-headers DISABLE_HEADERS Default: None. Comma separated list of headers to remove before dispatching client request to upstream diff --git a/proxy/__init__.py b/proxy/__init__.py index 815923b42..388c28386 100755 --- a/proxy/__init__.py +++ b/proxy/__init__.py @@ -8,9 +8,8 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ -from .proxy import entry_point -from .proxy import main, start -from .proxy import Proxy +from .proxy import Proxy, main, sleep_loop, entry_point + __all__ = [ # PyPi package entry_point. See @@ -18,6 +17,10 @@ 'entry_point', # Embed proxy.py. See # https://github.com/abhinavsingh/proxy.py#embed-proxypy - 'main', 'start', + 'main', + # Unit testing with proxy.py. See + # https://github.com/abhinavsingh/proxy.py#unit-testing-with-proxypy 'Proxy', + # Utility exposed for demos + 'sleep_loop', ] diff --git a/proxy/common/_version.py b/proxy/common/_version.py new file mode 100644 index 000000000..c259a3cff --- /dev/null +++ b/proxy/common/_version.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + Version definition. +""" +from typing import Tuple, Union + + +__version__ = '2.4.3' +_ver_tup = 2, 4, 3 + + +def _to_int_or_str(inp: str) -> Union[int, str]: # pragma: no cover + try: + return int(inp) + except ValueError: + return inp + + +def _split_version_parts(inp: str) -> Tuple[str, ...]: # pragma: no cover + public_version, _plus, local_version = inp.partition('+') + return *public_version.split('.'), local_version + + +try: + VERSION = _ver_tup +except NameError: # pragma: no cover + VERSION = tuple( + map(_to_int_or_str, _split_version_parts(__version__)), + ) + + +__all__ = '__version__', 'VERSION' diff --git a/proxy/common/backports.py b/proxy/common/backports.py new file mode 100644 index 000000000..0855f5439 --- /dev/null +++ b/proxy/common/backports.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import time +import threading +from queue import Empty +from typing import Any, Deque +from collections import deque + + +class cached_property: # pragma: no cover + """Decorator for read-only properties evaluated only once within TTL period. + It can be used to create a cached property like this:: + + import random + + # the class containing the property must be a new-style class + class MyClass: + # create property whose value is cached for ten minutes + @cached_property(ttl=600) + def randint(self): + # will only be evaluated every 10 min. at maximum. + return random.randint(0, 100) + + The value is cached in the '_cached_properties' attribute of the object instance that + has the property getter method wrapped by this decorator. 
+    The '_cached_properties' attribute value is a dictionary which has a key
+    for every property of the object which is wrapped by this decorator. Each
+    entry in the cache is created only when the property is accessed for the
+    first time and is a two-element tuple with the last computed property
+    value and the last time it was updated in seconds since the epoch.
+
+    The default time-to-live (TTL) is 0 seconds, i.e. a cached value will
+    never expire.
+
+    To expire a cached property value manually just do::
+
+        del instance._cached_properties[<property name>]
+
+    Adopted from https://wiki.python.org/moin/PythonDecoratorLibrary#Cached_Properties
+    © 2011 Christopher Arndt, MIT License.
+
+    NOTE: We need this class only because Python's in-built equivalent is
+    only available for 3.8+. Hence, we must get rid of it once proxy.py no
+    longer supports versions older than 3.8.
+
+    .. spelling::
+
+       backports
+       getter
+       Arndt
+       del
+    """
+
+    def __init__(self, ttl: float = 0):
+        self.ttl = ttl
+
+    def __call__(self, fget: Any, doc: Any = None) -> 'cached_property':
+        self.fget = fget
+        self.__doc__ = doc or fget.__doc__
+        self.__name__ = fget.__name__
+        self.__module__ = fget.__module__
+        return self
+
+    def __get__(self, inst: Any, owner: Any) -> Any:
+        now = time.time()
+        try:
+            value, last_update = inst._cached_properties[self.__name__]
+            if self.ttl > 0 and now - last_update > self.ttl:  # noqa: WPS333
+                raise AttributeError
+        except (KeyError, AttributeError):
+            value = self.fget(inst)
+            try:
+                cache = inst._cached_properties
+            except AttributeError:
+                cache, inst._cached_properties = {}, {}
+            finally:
+                cache[self.__name__] = (value, now)  # pylint: disable=E0601
+        return value
+
+
+class NonBlockingQueue:
+    '''Simple, unbounded, non-blocking FIFO queue.
+
+    Supports only a single consumer.
+
+    NOTE: This is available in Python since 3.7 as SimpleQueue.
+    It is here only because proxy.py still supports 3.6.
+    '''
+
+    def __init__(self) -> None:
+        self._queue: Deque[Any] = deque()
+        self._count: threading.Semaphore = threading.Semaphore(0)
+
+    def put(self, item: Any) -> None:
+        '''Put the item on the queue.'''
+        self._queue.append(item)
+        self._count.release()
+
+    def get(self) -> Any:
+        '''Remove and return an item from the queue.'''
+        if not self._count.acquire(False, None):
+            raise Empty
+        return self._queue.popleft()
+
+    def empty(self) -> bool:
+        '''Return True if the queue is empty, False otherwise (not reliable!).'''
+        return len(self._queue) == 0  # pragma: no cover
+
+    def qsize(self) -> int:
+        '''Return the approximate size of the queue (not reliable!).'''
+        return len(self._queue)  # pragma: no cover
diff --git a/proxy/common/constants.py b/proxy/common/constants.py
index c3349156a..2d9d8f2a8 100644
--- a/proxy/common/constants.py
+++ b/proxy/common/constants.py
@@ -9,26 +9,61 @@
     :license: BSD, see LICENSE for more details.
""" import os +import sys import time import pathlib +import secrets +import platform import ipaddress - -from typing import List +import sysconfig +from typing import Any, List from .version import __version__ + +SYS_PLATFORM = platform.system() +IS_WINDOWS = SYS_PLATFORM.lower() in ('windows', 'cygwin') + + +def _env_threadless_compliant() -> bool: + """Returns true for Python 3.8+ across all platforms + except Windows.""" + return not IS_WINDOWS and sys.version_info >= (3, 8) + + PROXY_PY_START_TIME = time.time() # /path/to/proxy.py/proxy folder PROXY_PY_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +# Path to virtualenv/lib/python3.X/site-packages +PROXY_PY_SITE_PACKAGES = sysconfig.get_path('purelib') +assert PROXY_PY_SITE_PACKAGES + CRLF = b'\r\n' COLON = b':' WHITESPACE = b' ' COMMA = b',' DOT = b'.' SLASH = b'/' -HTTP_1_1 = b'HTTP/1.1' +AT = b'@' +HTTP_PROTO = b'http' +HTTPS_PROTO = HTTP_PROTO + b's' +HTTP_1_0 = HTTP_PROTO.upper() + SLASH + b'1.0' +HTTP_1_1 = HTTP_PROTO.upper() + SLASH + b'1.1' +HTTP_URL_PREFIX = HTTP_PROTO + COLON + SLASH + SLASH +HTTPS_URL_PREFIX = HTTPS_PROTO + COLON + SLASH + SLASH + +LOCAL_INTERFACE_HOSTNAMES = ( + b'localhost', + b'127.0.0.1', + b'::1', +) + +ANY_INTERFACE_HOSTNAMES = ( + b'0.0.0.0', + b'::', +) PROXY_AGENT_HEADER_KEY = b'Proxy-agent' PROXY_AGENT_HEADER_VALUE = b'proxy.py v' + \ @@ -39,44 +74,84 @@ # Defaults DEFAULT_BACKLOG = 100 DEFAULT_BASIC_AUTH = None -DEFAULT_BUFFER_SIZE = 1024 * 1024 +DEFAULT_MAX_SEND_SIZE = 64 * 1024 +DEFAULT_BUFFER_SIZE = 128 * 1024 DEFAULT_CA_CERT_DIR = None DEFAULT_CA_CERT_FILE = None DEFAULT_CA_KEY_FILE = None DEFAULT_CA_SIGNING_KEY_FILE = None DEFAULT_CERT_FILE = None -DEFAULT_CA_FILE = None +DEFAULT_CA_FILE = pathlib.Path( + PROXY_PY_SITE_PACKAGES, +) / 'certifi' / 'cacert.pem' DEFAULT_CLIENT_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE -DEFAULT_DEVTOOLS_WS_PATH = b'/devtools' DEFAULT_DISABLE_HEADERS: List[bytes] = [] -DEFAULT_DISABLE_HTTP_PROXY = False +DEFAULT_ENABLE_SSH_TUNNEL = False DEFAULT_ENABLE_EVENTS = False DEFAULT_EVENTS_QUEUE = None DEFAULT_ENABLE_STATIC_SERVER = False DEFAULT_ENABLE_WEB_SERVER = False +DEFAULT_ALLOWED_URL_SCHEMES = [HTTP_PROTO, HTTPS_PROTO] DEFAULT_IPV4_HOSTNAME = ipaddress.IPv4Address('127.0.0.1') DEFAULT_IPV6_HOSTNAME = ipaddress.IPv6Address('::1') DEFAULT_KEY_FILE = None DEFAULT_LOG_FILE = None -DEFAULT_LOG_FORMAT = '%(asctime)s - pid:%(process)d [%(levelname)-.1s] %(funcName)s:%(lineno)d - %(message)s' +DEFAULT_LOG_FORMAT = '%(asctime)s - pid:%(process)d [%(levelname)-.1s] %(module)s.%(funcName)s:%(lineno)d - %(message)s' DEFAULT_LOG_LEVEL = 'INFO' +DEFAULT_WEB_ACCESS_LOG_FORMAT = '{client_ip}:{client_port} - ' \ + '{request_method} {request_path} - {request_ua} - {connection_time_ms}ms' +DEFAULT_HTTP_PROXY_ACCESS_LOG_FORMAT = '{client_ip}:{client_port} - ' + \ + '{request_method} {server_host}:{server_port}{request_path} - ' + \ + '{response_code} {response_reason} - {response_bytes} bytes - ' + \ + '{connection_time_ms}ms' +DEFAULT_HTTPS_PROXY_ACCESS_LOG_FORMAT = '{client_ip}:{client_port} - ' + \ + '{request_method} {server_host}:{server_port} - ' + \ + '{response_bytes} bytes - {connection_time_ms}ms' +DEFAULT_NUM_ACCEPTORS = 0 DEFAULT_NUM_WORKERS = 0 DEFAULT_OPEN_FILE_LIMIT = 1024 DEFAULT_PAC_FILE = None DEFAULT_PAC_FILE_URL_PATH = b'/' DEFAULT_PID_FILE = None -DEFAULT_PLUGINS = '' +DEFAULT_PORT_FILE = None +DEFAULT_PLUGINS: List[Any] = [] DEFAULT_PORT = 8899 DEFAULT_SERVER_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE DEFAULT_STATIC_SERVER_DIR = os.path.join(PROXY_PY_DIR, 
"public") -DEFAULT_THREADLESS = False -DEFAULT_TIMEOUT = 10 +DEFAULT_MIN_COMPRESSION_LENGTH = 20 # In bytes +DEFAULT_THREADLESS = _env_threadless_compliant() +DEFAULT_LOCAL_EXECUTOR = True +DEFAULT_TIMEOUT = 10.0 DEFAULT_VERSION = False DEFAULT_HTTP_PORT = 80 -DEFAULT_MAX_SEND_SIZE = 16 * 1024 +DEFAULT_HTTPS_PORT = 443 +DEFAULT_WORK_KLASS = 'proxy.http.HttpProtocolHandler' +DEFAULT_ENABLE_PROXY_PROTOCOL = False +# 25 milliseconds to keep the loops hot +# Will consume ~0.3-0.6% CPU when idle. +DEFAULT_SELECTOR_SELECT_TIMEOUT = 25 / 1000 +DEFAULT_WAIT_FOR_TASKS_TIMEOUT = 1 / 1000 +DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT = 1 # in seconds DEFAULT_DATA_DIRECTORY_PATH = os.path.join(str(pathlib.Path.home()), '.proxy') +DEFAULT_CACHE_DIRECTORY_PATH = os.path.join( + DEFAULT_DATA_DIRECTORY_PATH, 'cache', +) +DEFAULT_CACHE_REQUESTS = False +DEFAULT_CACHE_BY_CONTENT_TYPE = False # Cor plugins enabled by default or via flags +DEFAULT_ABC_PLUGINS = [ + 'HttpProtocolHandlerPlugin', + 'HttpProxyBasePlugin', + 'HttpWebServerBasePlugin', + 'WebSocketTransportBasePlugin', +] PLUGIN_HTTP_PROXY = 'proxy.http.proxy.HttpProxyPlugin' PLUGIN_WEB_SERVER = 'proxy.http.server.HttpWebServerPlugin' +PLUGIN_PAC_FILE = 'proxy.http.server.HttpWebServerPacFilePlugin' +PLUGIN_WEBSOCKET_TRANSPORT = 'proxy.http.websocket.transport.WebSocketTransport' + +PY2_DEPRECATION_MESSAGE = '''DEPRECATION: proxy.py no longer supports Python 2.7. Kindly upgrade to Python 3+. ' + 'If for some reasons you cannot upgrade, use' + '"pip install proxy.py==0.3".''' diff --git a/proxy/common/flag.py b/proxy/common/flag.py new file mode 100644 index 000000000..f09e96c65 --- /dev/null +++ b/proxy/common/flag.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import os +import sys +import base64 +import socket +import argparse +import ipaddress +import itertools +import collections +import multiprocessing +from typing import Any, List, Optional, cast + +from .types import IpAddress +from .utils import bytes_, is_py2, is_threadless, set_open_file_limit +from .logger import Logger +from .plugins import Plugins +from .version import __version__ +from .constants import ( + COMMA, IS_WINDOWS, PLUGIN_PAC_FILE, PLUGIN_HTTP_PROXY, + PLUGIN_WEB_SERVER, DEFAULT_NUM_WORKERS, + DEFAULT_NUM_ACCEPTORS, + DEFAULT_DISABLE_HEADERS, PY2_DEPRECATION_MESSAGE, + DEFAULT_DATA_DIRECTORY_PATH, DEFAULT_MIN_COMPRESSION_LENGTH, +) + + +__homepage__ = 'https://github.com/abhinavsingh/proxy.py' + + +# TODO: Currently `initialize` staticmethod contains knowledge +# about several common flags defined by proxy.py core. + +# This logic must be decoupled. flags.add_argument must +# also provide a callback to resolve the final flag value +# based upon availability in input_args, **opts and +# default values. + +# Supporting such a framework is complex but achievable. +# One problem is that resolution of certain flags +# can depend upon availability of other flags. + +# This will lead us into dependency graph modeling domain. +class FlagParser: + """Wrapper around argparse module. + + Import `flag.flags` and use `add_argument` API + to define custom flags within respective Python files. + + Best Practice:: + + 1. Define flags at the top of your class files. + 2. 
DO NOT add flags within your class `__init__` method OR + within class methods. It MAY result into runtime exception, + especially if your class is initialized multiple times or if + class method registering the flag gets invoked multiple times. + + """ + + def __init__(self) -> None: + self.args: Optional[argparse.Namespace] = None + self.actions: List[str] = [] + self.parser = argparse.ArgumentParser( + description='proxy.py v%s' % __version__, + epilog='Proxy.py not working? Report at: %s/issues/new' % __homepage__, + ) + + def add_argument(self, *args: Any, **kwargs: Any) -> argparse.Action: + """Register a flag.""" + action = self.parser.add_argument(*args, **kwargs) + self.actions.append(action.dest) + return action + + def parse_args( + self, input_args: Optional[List[str]], + ) -> argparse.Namespace: + """Parse flags from input arguments.""" + self.args = self.parser.parse_args(input_args) + return self.args + + @staticmethod + def initialize( + input_args: Optional[List[str]] = None, + **opts: Any, + ) -> argparse.Namespace: + if input_args is None: + input_args = [] + + if is_py2(): + print(PY2_DEPRECATION_MESSAGE) + sys.exit(1) + + # Discover flags from requested plugin. + # This will also surface external plugin flags + # under --help. + Plugins.discover(input_args) + + # Parse flags + args = flags.parse_args(input_args) + + # Print version and exit + if args.version: + print(__version__) + sys.exit(0) + + # proxy.py currently cannot serve over HTTPS and also perform TLS interception + # at the same time. Check if user is trying to enable both feature + # at the same time. + # + # TODO: Use parser.add_mutually_exclusive_group() + # and remove this logic from here. + if (args.cert_file and args.key_file) and \ + (args.ca_key_file and args.ca_cert_file and args.ca_signing_key_file): + print( + 'You can either enable end-to-end encryption OR TLS interception,' + 'not both together.', + ) + sys.exit(1) + + # Setup logging module + Logger.setup(args.log_file, args.log_level, args.log_format) + + # Setup limits + set_open_file_limit(args.open_file_limit) + + # Load work_klass + work_klass = opts.get('work_klass', args.work_klass) + work_klass = Plugins.importer(bytes_(work_klass))[0] \ + if isinstance(work_klass, str) \ + else work_klass + + # --enable flags must be parsed before loading plugins + # otherwise we will miss the plugins passed via constructor + args.enable_web_server = cast( + bool, + opts.get( + 'enable_web_server', + args.enable_web_server, + ), + ) + args.enable_static_server = cast( + bool, + opts.get( + 'enable_static_server', + args.enable_static_server, + ), + ) + args.enable_events = cast( + bool, + opts.get( + 'enable_events', + args.enable_events, + ), + ) + + # Load default plugins along with user provided --plugins + default_plugins = [ + bytes_(p) + for p in FlagParser.get_default_plugins(args) + ] + requested_plugins = Plugins.resolve_plugin_flag( + args.plugins, opts.get('plugins'), + ) + plugins = Plugins.load( + default_plugins + requested_plugins, + ) + + # https://github.com/python/mypy/issues/5865 + # + # def option(t: object, key: str, default: Any) -> Any: + # return cast(t, opts.get(key, default)) + args.work_klass = work_klass + args.plugins = plugins + args.server_recvbuf_size = cast( + int, + opts.get( + 'server_recvbuf_size', + args.server_recvbuf_size, + ), + ) + args.client_recvbuf_size = cast( + int, + opts.get( + 'client_recvbuf_size', + args.client_recvbuf_size, + ), + ) + args.pac_file = cast( + Optional[str], opts.get( + 
'pac_file', bytes_( + args.pac_file, + ), + ), + ) + args.pac_file_url_path = cast( + Optional[bytes], opts.get( + 'pac_file_url_path', bytes_( + args.pac_file_url_path, + ), + ), + ) + disabled_headers = cast( + Optional[List[bytes]], opts.get( + 'disable_headers', [ + header.lower() + for header in bytes_(getattr(args, 'disable_headers', b'')).split(COMMA) + if header.strip() != b'' + ], + ), + ) + args.disable_headers = disabled_headers if disabled_headers is not None else DEFAULT_DISABLE_HEADERS + + args.certfile = cast( + Optional[str], opts.get( + 'cert_file', args.cert_file, + ), + ) + args.keyfile = cast(Optional[str], opts.get('key_file', args.key_file)) + + args.ca_key_file = cast( + Optional[str], opts.get( + 'ca_key_file', getattr(args, 'ca_key_file', None) + ), + ) + args.ca_cert_file = cast( + Optional[str], opts.get( + 'ca_cert_file', getattr(args, 'ca_cert_file', None), + ), + ) + args.ca_signing_key_file = cast( + Optional[str], + opts.get( + 'ca_signing_key_file', + getattr(args, 'ca_signing_key_file', None) + ), + ) + args.ca_file = cast( + Optional[str], + opts.get( + 'ca_file', + getattr(args, 'ca_file', None) + ), + ) + args.openssl = cast( + Optional[str], + opts.get( + 'openssl', + getattr(args, 'openssl', None) + ), + ) + + args.hostname = cast( + IpAddress, + opts.get('hostname', ipaddress.ip_address(args.hostname)), + ) + args.unix_socket_path = opts.get( + 'unix_socket_path', getattr(args, 'unix_socket_path', None), + ) + # AF_UNIX is not available on Windows + # See https://bugs.python.org/issue33408 + if not IS_WINDOWS: + args.family = socket.AF_UNIX if args.unix_socket_path else ( + socket.AF_INET6 if args.hostname.version == 6 else socket.AF_INET + ) + else: + # FIXME: Not true for tests, as this value will be a mock. + # + # It's a problem only on Windows. Instead of a proper + # fix in the tests, simply commenting this line of assertion + # for now. + # + # assert args.unix_socket_path is None + args.family = socket.AF_INET6 if args.hostname.version == 6 else socket.AF_INET + args.port = cast(int, opts.get('port', args.port)) + ports: List[List[int]] = opts.get('ports', args.ports) + args.ports = [ + int(port) for port in list( + itertools.chain.from_iterable([] if ports is None else ports), + ) + ] + args.backlog = cast(int, opts.get('backlog', args.backlog)) + num_workers = opts.get('num_workers', args.num_workers) + args.num_workers = cast( + int, num_workers if num_workers > 0 else multiprocessing.cpu_count(), + ) + num_acceptors = opts.get('num_acceptors', args.num_acceptors) + # See https://github.com/abhinavsingh/proxy.py/pull/714 description + # to understand rationale behind the following logic. + # + # Num workers flag or option was found. We will use + # the same value for num_acceptors when num acceptors flag + # is absent. 
+ if num_workers != DEFAULT_NUM_WORKERS and num_acceptors == DEFAULT_NUM_ACCEPTORS: + args.num_acceptors = args.num_workers + else: + args.num_acceptors = cast( + int, num_acceptors if num_acceptors > 0 else multiprocessing.cpu_count(), + ) + + args.static_server_dir = cast( + str, + opts.get( + 'static_server_dir', + args.static_server_dir, + ), + ) + args.min_compression_length = cast( + bool, + opts.get( + 'min_compression_length', + getattr( + args, 'min_compression_length', + DEFAULT_MIN_COMPRESSION_LENGTH, + ), + ), + ) + args.timeout = cast(int, opts.get('timeout', args.timeout)) + args.local_executor = cast( + int, + opts.get( + 'local_executor', + args.local_executor, + ), + ) + args.threaded = cast(bool, opts.get('threaded', args.threaded)) + # Pre-evaluate threadless values based upon environment and config + # + # --threadless is now default mode of execution + # but we still have exceptions based upon OS config. + # Make sure executors are not started if is_threadless + # evaluates to False. + args.threadless = cast(bool, opts.get('threadless', args.threadless)) + args.threadless = is_threadless(args.threadless, args.threaded) + + args.pid_file = cast( + Optional[str], opts.get( + 'pid_file', + args.pid_file, + ), + ) + + args.port_file = cast( + Optional[str], opts.get( + 'port_file', + args.port_file, + ), + ) + + args.proxy_py_data_dir = DEFAULT_DATA_DIRECTORY_PATH + os.makedirs(args.proxy_py_data_dir, exist_ok=True) + + ca_cert_dir = opts.get('ca_cert_dir', getattr(args, 'ca_cert_dir', None)) + args.ca_cert_dir = cast(Optional[str], ca_cert_dir) + if args.ca_cert_dir is None: + args.ca_cert_dir = os.path.join( + args.proxy_py_data_dir, 'certificates', + ) + os.makedirs(args.ca_cert_dir, exist_ok=True) + + # FIXME: Necessary here until flags framework provides a way + # for flag owners to initialize + args.cache_dir = getattr(args, 'cache_dir', 'cache') + os.makedirs(args.cache_dir, exist_ok=True) + os.makedirs(os.path.join(args.cache_dir, 'responses'), exist_ok=True) + os.makedirs(os.path.join(args.cache_dir, 'content'), exist_ok=True) + + return args + + @staticmethod + def get_default_plugins( + args: argparse.Namespace, + ) -> List[str]: + """Prepare list of plugins to load based upon + --enable-* and --disable-* flags. + """ + default_plugins: List[str] = [] + default_plugins.append(PLUGIN_HTTP_PROXY) + if args.enable_web_server or \ + args.pac_file is not None or \ + args.enable_static_server: + default_plugins.append(PLUGIN_WEB_SERVER) + if args.pac_file is not None: + default_plugins.append(PLUGIN_PAC_FILE) + return list(collections.OrderedDict.fromkeys(default_plugins).keys()) + + +flags = FlagParser() diff --git a/proxy/common/flags.py b/proxy/common/flags.py deleted file mode 100644 index 7aa279913..000000000 --- a/proxy/common/flags.py +++ /dev/null @@ -1,541 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. - :license: BSD, see LICENSE for more details. 
-""" -import abc -import logging -import importlib -import collections -import argparse -import base64 -import ipaddress -import os -import socket -import multiprocessing -import sys -import inspect - -from typing import Optional, Union, Dict, List, TypeVar, Type, cast, Any, Tuple - -from .utils import text_, bytes_ -from .constants import DEFAULT_LOG_LEVEL, DEFAULT_LOG_FILE, DEFAULT_LOG_FORMAT, DEFAULT_BACKLOG, DEFAULT_BASIC_AUTH -from .constants import DEFAULT_TIMEOUT, DEFAULT_DEVTOOLS_WS_PATH, DEFAULT_DISABLE_HTTP_PROXY, DEFAULT_DISABLE_HEADERS -from .constants import DEFAULT_ENABLE_STATIC_SERVER, DEFAULT_ENABLE_EVENTS -from .constants import DEFAULT_ENABLE_WEB_SERVER, DEFAULT_THREADLESS, DEFAULT_CERT_FILE, DEFAULT_KEY_FILE, DEFAULT_CA_FILE -from .constants import DEFAULT_CA_CERT_DIR, DEFAULT_CA_CERT_FILE, DEFAULT_CA_KEY_FILE, DEFAULT_CA_SIGNING_KEY_FILE -from .constants import DEFAULT_PAC_FILE_URL_PATH, DEFAULT_PAC_FILE, DEFAULT_PLUGINS, DEFAULT_PID_FILE, DEFAULT_PORT -from .constants import DEFAULT_NUM_WORKERS, DEFAULT_VERSION, DEFAULT_OPEN_FILE_LIMIT, DEFAULT_IPV6_HOSTNAME -from .constants import DEFAULT_SERVER_RECVBUF_SIZE, DEFAULT_CLIENT_RECVBUF_SIZE, DEFAULT_STATIC_SERVER_DIR -from .constants import DEFAULT_DATA_DIRECTORY_PATH, COMMA, DOT -from .constants import PLUGIN_HTTP_PROXY, PLUGIN_WEB_SERVER -from .version import __version__ - -__homepage__ = 'https://github.com/abhinavsingh/proxy.py' - -if os.name != 'nt': - import resource - -logger = logging.getLogger(__name__) - -T = TypeVar('T', bound='Flags') - - -class Flags: - """Contains all input flags and inferred input parameters.""" - - def __init__( - self, - auth_code: Optional[bytes] = DEFAULT_BASIC_AUTH, - server_recvbuf_size: int = DEFAULT_SERVER_RECVBUF_SIZE, - client_recvbuf_size: int = DEFAULT_CLIENT_RECVBUF_SIZE, - pac_file: Optional[str] = DEFAULT_PAC_FILE, - pac_file_url_path: Optional[bytes] = DEFAULT_PAC_FILE_URL_PATH, - plugins: Optional[Dict[bytes, List[type]]] = None, - disable_headers: Optional[List[bytes]] = None, - certfile: Optional[str] = None, - keyfile: Optional[str] = None, - ca_cert_dir: Optional[str] = None, - ca_key_file: Optional[str] = None, - ca_cert_file: Optional[str] = None, - ca_signing_key_file: Optional[str] = None, - ca_file: Optional[str] = None, - num_workers: int = 0, - hostname: Union[ipaddress.IPv4Address, - ipaddress.IPv6Address] = DEFAULT_IPV6_HOSTNAME, - port: int = DEFAULT_PORT, - backlog: int = DEFAULT_BACKLOG, - static_server_dir: str = DEFAULT_STATIC_SERVER_DIR, - enable_static_server: bool = DEFAULT_ENABLE_STATIC_SERVER, - devtools_ws_path: bytes = DEFAULT_DEVTOOLS_WS_PATH, - timeout: int = DEFAULT_TIMEOUT, - threadless: bool = DEFAULT_THREADLESS, - enable_events: bool = DEFAULT_ENABLE_EVENTS, - pid_file: Optional[str] = DEFAULT_PID_FILE) -> None: - self.pid_file = pid_file - self.threadless = threadless - self.timeout = timeout - self.auth_code = auth_code - self.server_recvbuf_size = server_recvbuf_size - self.client_recvbuf_size = client_recvbuf_size - self.pac_file = pac_file - self.pac_file_url_path = pac_file_url_path - if plugins is None: - plugins = {} - self.plugins: Dict[bytes, List[type]] = plugins - if disable_headers is None: - disable_headers = DEFAULT_DISABLE_HEADERS - self.disable_headers = disable_headers - self.certfile: Optional[str] = certfile - self.keyfile: Optional[str] = keyfile - self.ca_key_file: Optional[str] = ca_key_file - self.ca_cert_file: Optional[str] = ca_cert_file - self.ca_signing_key_file: Optional[str] = ca_signing_key_file - 
self.ca_file = ca_file - self.num_workers: int = num_workers if num_workers > 0 else multiprocessing.cpu_count() - self.hostname: Union[ipaddress.IPv4Address, - ipaddress.IPv6Address] = hostname - self.family: socket.AddressFamily = socket.AF_INET6 if hostname.version == 6 else socket.AF_INET - self.port: int = port - self.backlog: int = backlog - - self.enable_static_server: bool = enable_static_server - self.static_server_dir: str = static_server_dir - self.devtools_ws_path: bytes = devtools_ws_path - self.enable_events: bool = enable_events - - self.proxy_py_data_dir = DEFAULT_DATA_DIRECTORY_PATH - os.makedirs(self.proxy_py_data_dir, exist_ok=True) - - self.ca_cert_dir: Optional[str] = ca_cert_dir - if self.ca_cert_dir is None: - self.ca_cert_dir = os.path.join( - self.proxy_py_data_dir, 'certificates') - os.makedirs(self.ca_cert_dir, exist_ok=True) - - def tls_interception_enabled(self) -> bool: - return self.ca_key_file is not None and \ - self.ca_cert_dir is not None and \ - self.ca_signing_key_file is not None and \ - self.ca_cert_file is not None - - def encryption_enabled(self) -> bool: - return self.keyfile is not None and \ - self.certfile is not None - - @classmethod - def initialize( - cls: Type[T], - input_args: Optional[List[str]], - **opts: Any) -> T: - if not Flags.is_py3(): - print( - 'DEPRECATION: "develop" branch no longer supports Python 2.7. Kindly upgrade to Python 3+. ' - 'If for some reasons you cannot upgrade, consider using "master" branch or simply ' - '"pip install proxy.py==0.3".' - '\n\n' - 'DEPRECATION: Python 2.7 will reach the end of its life on January 1st, 2020. ' - 'Please upgrade your Python as Python 2.7 won\'t be maintained after that date. ' - 'A future version of pip will drop support for Python 2.7.') - sys.exit(1) - - parser = Flags.init_parser() - args = parser.parse_args(input_args) - - # Print version and exit - if args.version: - print(__version__) - sys.exit(0) - - # Setup logging module - Flags.setup_logger(args.log_file, args.log_level, args.log_format) - - # Setup limits - Flags.set_open_file_limit(args.open_file_limit) - - default_plugins: List[Tuple[str, bool]] = [] - if not args.disable_http_proxy: - default_plugins.append((PLUGIN_HTTP_PROXY, True)) - if args.enable_web_server or \ - args.pac_file is not None or \ - args.enable_static_server: - default_plugins.append((PLUGIN_WEB_SERVER, True)) - - plugins = Flags.load_plugins( - bytes_( - '%s,%s' % - (text_(COMMA).join(collections.OrderedDict(default_plugins).keys()), - opts.get('plugins', args.plugins)))) - - # proxy.py currently cannot serve over HTTPS and perform TLS interception - # at the same time. Check if user is trying to enable both feature - # at the same time. 
- if (args.cert_file and args.key_file) and \ - (args.ca_key_file and args.ca_cert_file and args.ca_signing_key_file): - print('You can either enable end-to-end encryption OR TLS interception,' - 'not both together.') - sys.exit(1) - - # Generate auth_code required for basic authentication if enabled - auth_code = None - if args.basic_auth: - auth_code = b'Basic %s' % base64.b64encode(bytes_(args.basic_auth)) - - return cls( - auth_code=cast(Optional[bytes], opts.get('auth_code', auth_code)), - server_recvbuf_size=cast( - int, - opts.get( - 'server_recvbuf_size', - args.server_recvbuf_size)), - client_recvbuf_size=cast( - int, - opts.get( - 'client_recvbuf_size', - args.client_recvbuf_size)), - pac_file=cast( - Optional[str], opts.get( - 'pac_file', bytes_( - args.pac_file))), - pac_file_url_path=cast( - Optional[bytes], opts.get( - 'pac_file_url_path', bytes_( - args.pac_file_url_path))), - disable_headers=cast(Optional[List[bytes]], opts.get('disable_headers', [ - header.lower() for header in bytes_( - args.disable_headers).split(COMMA) if header.strip() != b''])), - certfile=cast( - Optional[str], opts.get( - 'cert_file', args.cert_file)), - keyfile=cast(Optional[str], opts.get('key_file', args.key_file)), - ca_cert_dir=cast( - Optional[str], opts.get( - 'ca_cert_dir', args.ca_cert_dir)), - ca_key_file=cast( - Optional[str], opts.get( - 'ca_key_file', args.ca_key_file)), - ca_cert_file=cast( - Optional[str], opts.get( - 'ca_cert_file', args.ca_cert_file)), - ca_signing_key_file=cast( - Optional[str], - opts.get( - 'ca_signing_key_file', - args.ca_signing_key_file)), - ca_file=cast( - Optional[str], - opts.get( - 'ca_file', - args.ca_file)), - hostname=cast(Union[ipaddress.IPv4Address, - ipaddress.IPv6Address], - opts.get('hostname', ipaddress.ip_address(args.hostname))), - port=cast(int, opts.get('port', args.port)), - backlog=cast(int, opts.get('backlog', args.backlog)), - num_workers=cast(int, opts.get('num_workers', args.num_workers)), - static_server_dir=cast( - str, - opts.get( - 'static_server_dir', - args.static_server_dir)), - enable_static_server=cast( - bool, - opts.get( - 'enable_static_server', - args.enable_static_server)), - devtools_ws_path=cast( - bytes, - opts.get( - 'devtools_ws_path', - args.devtools_ws_path)), - timeout=cast(int, opts.get('timeout', args.timeout)), - threadless=cast(bool, opts.get('threadless', args.threadless)), - enable_events=cast( - bool, - opts.get( - 'enable_events', - args.enable_events)), - plugins=plugins, - pid_file=cast(Optional[str], opts.get('pid_file', args.pid_file))) - - @staticmethod - def init_parser() -> argparse.ArgumentParser: - """Initializes and returns argument parser.""" - parser = argparse.ArgumentParser( - description='proxy.py v%s' % __version__, - epilog='Proxy.py not working? Report at: %s/issues/new' % __homepage__ - ) - # Argument names are ordered alphabetically. - parser.add_argument( - '--backlog', - type=int, - default=DEFAULT_BACKLOG, - help='Default: 100. Maximum number of pending connections to proxy server') - parser.add_argument( - '--basic-auth', - type=str, - default=DEFAULT_BASIC_AUTH, - help='Default: No authentication. Specify colon separated user:password ' - 'to enable basic authentication.') - parser.add_argument( - '--ca-key-file', - type=str, - default=DEFAULT_CA_KEY_FILE, - help='Default: None. CA key to use for signing dynamically generated ' - 'HTTPS certificates. 
If used, must also pass --ca-cert-file and --ca-signing-key-file' - ) - parser.add_argument( - '--ca-cert-dir', - type=str, - default=DEFAULT_CA_CERT_DIR, - help='Default: ~/.proxy.py. Directory to store dynamically generated certificates. ' - 'Also see --ca-key-file, --ca-cert-file and --ca-signing-key-file' - ) - parser.add_argument( - '--ca-cert-file', - type=str, - default=DEFAULT_CA_CERT_FILE, - help='Default: None. Signing certificate to use for signing dynamically generated ' - 'HTTPS certificates. If used, must also pass --ca-key-file and --ca-signing-key-file' - ) - parser.add_argument( - '--ca-file', - type=str, - default=DEFAULT_CA_FILE, - help='Default: None. Provide path to custom CA file for peer certificate validation. ' - 'Specially useful on MacOS.' - ) - parser.add_argument( - '--ca-signing-key-file', - type=str, - default=DEFAULT_CA_SIGNING_KEY_FILE, - help='Default: None. CA signing key to use for dynamic generation of ' - 'HTTPS certificates. If used, must also pass --ca-key-file and --ca-cert-file' - ) - parser.add_argument( - '--cert-file', - type=str, - default=DEFAULT_CERT_FILE, - help='Default: None. Server certificate to enable end-to-end TLS encryption with clients. ' - 'If used, must also pass --key-file.' - ) - parser.add_argument( - '--client-recvbuf-size', - type=int, - default=DEFAULT_CLIENT_RECVBUF_SIZE, - help='Default: 1 MB. Maximum amount of data received from the ' - 'client in a single recv() operation. Bump this ' - 'value for faster uploads at the expense of ' - 'increased RAM.') - parser.add_argument( - '--devtools-ws-path', - type=str, - default=DEFAULT_DEVTOOLS_WS_PATH, - help='Default: /devtools. Only applicable ' - 'if --enable-devtools is used.' - ) - parser.add_argument( - '--disable-headers', - type=str, - default=COMMA.join(DEFAULT_DISABLE_HEADERS), - help='Default: None. Comma separated list of headers to remove before ' - 'dispatching client request to upstream server.') - parser.add_argument( - '--disable-http-proxy', - action='store_true', - default=DEFAULT_DISABLE_HTTP_PROXY, - help='Default: False. Whether to disable proxy.HttpProxyPlugin.') - parser.add_argument( - '--enable-events', - action='store_true', - default=DEFAULT_ENABLE_EVENTS, - help='Default: False. Enables core to dispatch lifecycle events. ' - 'Plugins can be used to subscribe for core events.' - ) - parser.add_argument( - '--enable-static-server', - action='store_true', - default=DEFAULT_ENABLE_STATIC_SERVER, - help='Default: False. Enable inbuilt static file server. ' - 'Optionally, also use --static-server-dir to serve static content ' - 'from custom directory. By default, static file server serves ' - 'out of installed proxy.py python module folder.' - ) - parser.add_argument( - '--enable-web-server', - action='store_true', - default=DEFAULT_ENABLE_WEB_SERVER, - help='Default: False. Whether to enable proxy.HttpWebServerPlugin.') - parser.add_argument( - '--hostname', - type=str, - default=str(DEFAULT_IPV6_HOSTNAME), - help='Default: ::1. Server IP address.') - parser.add_argument( - '--key-file', - type=str, - default=DEFAULT_KEY_FILE, - help='Default: None. Server key file to enable end-to-end TLS encryption with clients. ' - 'If used, must also pass --cert-file.' - ) - parser.add_argument( - '--log-level', - type=str, - default=DEFAULT_LOG_LEVEL, - help='Valid options: DEBUG, INFO (default), WARNING, ERROR, CRITICAL. ' - 'Both upper and lowercase values are allowed. ' - 'You may also simply use the leading character e.g. 
--log-level d') - parser.add_argument('--log-file', type=str, default=DEFAULT_LOG_FILE, - help='Default: sys.stdout. Log file destination.') - parser.add_argument('--log-format', type=str, default=DEFAULT_LOG_FORMAT, - help='Log format for Python logger.') - parser.add_argument('--num-workers', type=int, default=DEFAULT_NUM_WORKERS, - help='Defaults to number of CPU cores.') - parser.add_argument( - '--open-file-limit', - type=int, - default=DEFAULT_OPEN_FILE_LIMIT, - help='Default: 1024. Maximum number of files (TCP connections) ' - 'that proxy.py can open concurrently.') - parser.add_argument( - '--pac-file', - type=str, - default=DEFAULT_PAC_FILE, - help='A file (Proxy Auto Configuration) or string to serve when ' - 'the server receives a direct file request. ' - 'Using this option enables proxy.HttpWebServerPlugin.') - parser.add_argument( - '--pac-file-url-path', - type=str, - default=text_(DEFAULT_PAC_FILE_URL_PATH), - help='Default: %s. Web server path to serve the PAC file.' % - text_(DEFAULT_PAC_FILE_URL_PATH)) - parser.add_argument( - '--pid-file', - type=str, - default=DEFAULT_PID_FILE, - help='Default: None. Save parent process ID to a file.') - parser.add_argument( - '--plugins', - type=str, - default=DEFAULT_PLUGINS, - help='Comma separated plugins') - parser.add_argument('--port', type=int, default=DEFAULT_PORT, - help='Default: 8899. Server port.') - parser.add_argument( - '--server-recvbuf-size', - type=int, - default=DEFAULT_SERVER_RECVBUF_SIZE, - help='Default: 1 MB. Maximum amount of data received from the ' - 'server in a single recv() operation. Bump this ' - 'value for faster downloads at the expense of ' - 'increased RAM.') - parser.add_argument( - '--static-server-dir', - type=str, - default=DEFAULT_STATIC_SERVER_DIR, - help='Default: "public" folder in directory where proxy.py is placed. ' - 'This option is only applicable when static server is also enabled. ' - 'See --enable-static-server.' - ) - parser.add_argument( - '--threadless', - action='store_true', - default=DEFAULT_THREADLESS, - help='Default: False. When disabled a new thread is spawned ' - 'to handle each client connection.' - ) - parser.add_argument( - '--timeout', - type=int, - default=DEFAULT_TIMEOUT, - help='Default: ' + str(DEFAULT_TIMEOUT) + - '. Number of seconds after which ' - 'an inactive connection must be dropped. Inactivity is defined by no ' - 'data sent or received by the client.' 
- ) - parser.add_argument( - '--version', - '-v', - action='store_true', - default=DEFAULT_VERSION, - help='Prints proxy.py version.') - return parser - - @staticmethod - def set_open_file_limit(soft_limit: int) -> None: - """Configure open file description soft limit on supported OS.""" - if os.name != 'nt': # resource module not available on Windows OS - curr_soft_limit, curr_hard_limit = resource.getrlimit( - resource.RLIMIT_NOFILE) - if curr_soft_limit < soft_limit < curr_hard_limit: - resource.setrlimit( - resource.RLIMIT_NOFILE, (soft_limit, curr_hard_limit)) - logger.debug( - 'Open file soft limit set to %d', soft_limit) - - @staticmethod - def load_plugins(plugins: bytes) -> Dict[bytes, List[type]]: - """Accepts a comma separated list of Python modules and returns - a list of respective Python classes.""" - p: Dict[bytes, List[type]] = { - b'HttpProtocolHandlerPlugin': [], - b'HttpProxyBasePlugin': [], - b'HttpWebServerBasePlugin': [], - } - for plugin_ in plugins.split(COMMA): - plugin = text_(plugin_.strip()) - if plugin == '': - continue - module_name, klass_name = plugin.rsplit(text_(DOT), 1) - klass = getattr( - importlib.import_module( - module_name.replace( - os.path.sep, text_(DOT))), - klass_name) - mro = list(inspect.getmro(klass)) - mro.reverse() - iterator = iter(mro) - while next(iterator) is not abc.ABC: - pass - base_klass = next(iterator) - p[bytes_(base_klass.__name__)].append(klass) - logger.info( - 'Loaded %s %s.%s', - 'plugin' if klass.__name__ != 'HttpWebServerRouteHandler' else 'route', - module_name, - # HttpWebServerRouteHandler route decorator adds a special - # staticmethod to return decorated function name - klass.__name__ if klass.__name__ != 'HttpWebServerRouteHandler' else klass.name()) - return p - - @staticmethod - def setup_logger( - log_file: Optional[str] = DEFAULT_LOG_FILE, - log_level: str = DEFAULT_LOG_LEVEL, - log_format: str = DEFAULT_LOG_FORMAT) -> None: - ll = getattr( - logging, - {'D': 'DEBUG', - 'I': 'INFO', - 'W': 'WARNING', - 'E': 'ERROR', - 'C': 'CRITICAL'}[log_level.upper()[0]]) - if log_file: - logging.basicConfig( - filename=log_file, - filemode='a', - level=ll, - format=log_format) - else: - logging.basicConfig(level=ll, format=log_format) - - @staticmethod - def is_py3() -> bool: - """Exists only to avoid mocking sys.version_info in tests.""" - return sys.version_info[0] == 3 diff --git a/proxy/common/logger.py b/proxy/common/logger.py new file mode 100644 index 000000000..74421d47b --- /dev/null +++ b/proxy/common/logger.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. 
+""" +import logging +from typing import Any, Optional + +from .constants import DEFAULT_LOG_FILE, DEFAULT_LOG_LEVEL, DEFAULT_LOG_FORMAT + + +SINGLE_CHAR_TO_LEVEL = { + 'D': 'DEBUG', + 'I': 'INFO', + 'W': 'WARNING', + 'E': 'ERROR', + 'C': 'CRITICAL', +} + + +def single_char_to_level(char: str) -> Any: + return getattr(logging, SINGLE_CHAR_TO_LEVEL[char.upper()[0]]) + + +class Logger: + """Common logging utilities and setup.""" + + @staticmethod + def setup( + log_file: Optional[str] = DEFAULT_LOG_FILE, + log_level: str = DEFAULT_LOG_LEVEL, + log_format: str = DEFAULT_LOG_FORMAT, + ) -> None: + if log_file: # pragma: no cover + logging.basicConfig( + filename=log_file, + filemode='a', + level=single_char_to_level(log_level), + format=log_format, + ) + else: + logging.basicConfig( + level=single_char_to_level(log_level), + format=log_format, + ) diff --git a/proxy/common/pki.py b/proxy/common/pki.py index 6361611b0..bdc2c5f3f 100644 --- a/proxy/common/pki.py +++ b/proxy/common/pki.py @@ -7,21 +7,28 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. spelling:: + + pki """ +import os import sys +import time +import uuid +import logging import argparse +import tempfile import contextlib -import os -import uuid import subprocess -import tempfile -import logging -from typing import List, Generator, Optional, Tuple +from typing import List, Tuple, Optional, Generator from .utils import bytes_ -from .constants import COMMA from .version import __version__ +from .constants import COMMA + +logger = logging.getLogger(__name__) DEFAULT_CONFIG = b'''[ req ] @@ -54,13 +61,15 @@ def remove_passphrase( key_in_path: str, password: str, key_out_path: str, - timeout: int = 10) -> bool: + timeout: int = 10, + openssl: str = 'openssl', +) -> bool: """Remove passphrase from a private key.""" command = [ - 'openssl', 'rsa', + openssl, 'rsa', '-passin', 'pass:%s' % password, '-in', key_in_path, - '-out', key_out_path + '-out', key_out_path, ] return run_openssl_command(command, timeout) @@ -69,12 +78,14 @@ def gen_private_key( key_path: str, password: str, bits: int = 2048, - timeout: int = 10) -> bool: + timeout: int = 10, + openssl: str = 'openssl', +) -> bool: """Generates a private key.""" command = [ - 'openssl', 'genrsa', '-aes256', + openssl, 'genrsa', '-aes256', '-passout', 'pass:%s' % password, - '-out', key_path, str(bits) + '-out', key_path, str(bits), ] return run_openssl_command(command, timeout) @@ -87,15 +98,17 @@ def gen_public_key( alt_subj_names: Optional[List[str]] = None, extended_key_usage: Optional[str] = None, validity_in_days: int = 365, - timeout: int = 10) -> bool: + timeout: int = 10, + openssl: str = 'openssl', +) -> bool: """For a given private key, generates a corresponding public key.""" with ssl_config(alt_subj_names, extended_key_usage) as (config_path, has_extension): command = [ - 'openssl', 'req', '-new', '-x509', '-sha256', + openssl, 'req', '-new', '-x509', '-sha256', '-days', str(validity_in_days), '-subj', subject, '-passin', 'pass:%s' % private_key_password, '-config', config_path, - '-key', private_key_path, '-out', public_key_path + '-key', private_key_path, '-out', public_key_path, ] if has_extension: command.extend([ @@ -109,13 +122,15 @@ def gen_csr( key_path: str, password: str, crt_path: str, - timeout: int = 10) -> bool: + timeout: int = 10, + openssl: str = 'openssl', +) -> bool: """Generates a CSR based upon existing certificate and key file.""" command = [ - 'openssl', 'x509', '-x509toreq', + 
openssl, 'x509', '-x509toreq', '-passin', 'pass:%s' % password, '-in', crt_path, '-signkey', key_path, - '-out', csr_path + '-out', csr_path, ] return run_openssl_command(command, timeout) @@ -130,11 +145,13 @@ def sign_csr( alt_subj_names: Optional[List[str]] = None, extended_key_usage: Optional[str] = None, validity_in_days: int = 365, - timeout: int = 10) -> bool: + timeout: int = 10, + openssl: str = 'openssl', +) -> bool: """Sign a CSR using CA key and certificate.""" with ext_file(alt_subj_names, extended_key_usage) as extension_path: command = [ - 'openssl', 'x509', '-req', '-sha256', + openssl, 'x509', '-req', '-sha256', '-CA', ca_crt_path, '-CAkey', ca_key_path, '-passin', 'pass:%s' % ca_key_password, @@ -149,7 +166,8 @@ def sign_csr( def get_ext_config( alt_subj_names: Optional[List[str]] = None, - extended_key_usage: Optional[str] = None) -> bytes: + extended_key_usage: Optional[str] = None, +) -> bytes: config = b'' # Add SAN extension if alt_subj_names is not None and len(alt_subj_names) > 0: @@ -166,12 +184,14 @@ def get_ext_config( @contextlib.contextmanager def ext_file( alt_subj_names: Optional[List[str]] = None, - extended_key_usage: Optional[str] = None) -> Generator[str, None, None]: + extended_key_usage: Optional[str] = None, +) -> Generator[str, None, None]: # Write config to temp file config_path = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex) with open(config_path, 'wb') as cnf: cnf.write( - get_ext_config(alt_subj_names, extended_key_usage)) + get_ext_config(alt_subj_names, extended_key_usage), + ) yield config_path @@ -182,7 +202,8 @@ def ext_file( @contextlib.contextmanager def ssl_config( alt_subj_names: Optional[List[str]] = None, - extended_key_usage: Optional[str] = None) -> Generator[Tuple[str, bool], None, None]: + extended_key_usage: Optional[str] = None, +) -> Generator[Tuple[str, bool], None, None]: config = DEFAULT_CONFIG has_extension = False @@ -209,7 +230,7 @@ def run_openssl_command(command: List[str], timeout: int) -> bool: cmd = subprocess.Popen( command, stdout=subprocess.PIPE, - stderr=subprocess.PIPE + stderr=subprocess.PIPE, ) cmd.communicate(timeout=timeout) return cmd.returncode == 0 @@ -218,7 +239,7 @@ def run_openssl_command(command: List[str], timeout: int) -> bool: if __name__ == '__main__': available_actions = ( 'remove_passphrase', 'gen_private_key', 'gen_public_key', - 'gen_csr', 'sign_csr' + 'gen_csr', 'sign_csr', ) parser = argparse.ArgumentParser( @@ -228,7 +249,7 @@ def run_openssl_command(command: List[str], timeout: int) -> bool: 'action', type=str, default=None, - help='Valid actions: ' + ', '.join(available_actions) + help='Valid actions: ' + ', '.join(available_actions), ) parser.add_argument( '--password', @@ -251,30 +272,80 @@ def run_openssl_command(command: List[str], timeout: int) -> bool: parser.add_argument( '--subject', type=str, - default='/CN=example.com', - help='Subject to use for public key generation. Default: /CN=example.com', + default='/CN=localhost', + help='Subject to use for public key generation. Default: /CN=localhost', + ) + parser.add_argument( + '--csr-path', + type=str, + default=None, + help='CSR file path. Use with gen_csr and sign_csr action.', + ) + parser.add_argument( + '--crt-path', + type=str, + default=None, + help='Signed certificate path. 
Use with sign_csr action.', + ) + parser.add_argument( + '--hostname', + type=str, + default=None, + help='Alternative subject names to use during CSR signing.', + ) + parser.add_argument( + '--openssl', + type=str, + default='openssl', + help='Path to openssl binary. By default, we assume openssl is in your PATH', ) args = parser.parse_args(sys.argv[1:]) # Validation if args.action not in available_actions: - print('Invalid --action. Valid values ' + ', '.join(available_actions)) + logger.error( + 'Invalid --action. Valid values ' + + ', '.join(available_actions), + ) + sys.exit(1) + if args.action in ('gen_private_key', 'gen_public_key') and \ + args.private_key_path is None: + logger.error('--private-key-path is required for ' + args.action) + sys.exit(1) + if args.action == 'gen_public_key' and \ + args.public_key_path is None: + logger.error( + '--public-key-file is required for private key generation', + ) sys.exit(1) - if args.action in ('gen_private_key', 'gen_public_key'): - if args.private_key_path is None: - print('--private-key-path is required for ' + args.action) - sys.exit(1) - if args.action == 'gen_public_key': - if args.public_key_path is None: - print('--public-key-file is required for private key generation') - sys.exit(1) # Execute if args.action == 'gen_private_key': - gen_private_key(args.private_key_path, args.password) + gen_private_key( + args.private_key_path, + args.password, openssl=args.openssl, + ) elif args.action == 'gen_public_key': - gen_public_key(args.public_key_path, args.private_key_path, - args.password, args.subject) + gen_public_key( + args.public_key_path, args.private_key_path, + args.password, args.subject, openssl=args.openssl, + ) elif args.action == 'remove_passphrase': - remove_passphrase(args.private_key_path, args.password, - args.private_key_path) + remove_passphrase( + args.private_key_path, args.password, + args.private_key_path, openssl=args.openssl, + ) + elif args.action == 'gen_csr': + gen_csr( + args.csr_path, + args.private_key_path, + args.password, + args.public_key_path, + openssl=args.openssl, + ) + elif args.action == 'sign_csr': + sign_csr( + args.csr_path, args.crt_path, args.private_key_path, args.password, + args.public_key_path, str(int(time.time())), alt_subj_names=[args.hostname], + openssl=args.openssl, + ) diff --git a/proxy/common/plugins.py b/proxy/common/plugins.py new file mode 100644 index 000000000..c919154cc --- /dev/null +++ b/proxy/common/plugins.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. 
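+
+    A rough usage sketch (the dotted path is a hypothetical plugin;
+    ``Plugins`` is defined below)::
+
+        from proxy.common.plugins import Plugins
+
+        klass, module_name = Plugins.importer(b'my_pkg.MyProxyPlugin')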
+""" +import os +import inspect +import logging +import importlib +import itertools +from typing import Any, Dict, List, Tuple, Union, Optional + +from .utils import text_, bytes_ +from .constants import DOT, COMMA, DEFAULT_ABC_PLUGINS + + +logger = logging.getLogger(__name__) + + +class Plugins: + """Common utilities for plugin discovery.""" + + @staticmethod + def resolve_plugin_flag(flag_plugins: Any, opt_plugins: Optional[Any] = None) -> List[Union[bytes, type]]: + if isinstance(flag_plugins, list): + requested_plugins = list( + itertools.chain.from_iterable([ + p.split(text_(COMMA)) for p in list( + itertools.chain.from_iterable(flag_plugins), + ) + ]), + ) + else: + requested_plugins = flag_plugins.split(text_(COMMA)) + return [ + p if isinstance(p, type) else bytes_(p) + for p in (opt_plugins if opt_plugins is not None else requested_plugins) + if not (isinstance(p, str) and len(p) == 0) + ] + + @staticmethod + def discover(input_args: List[str]) -> None: + """Search for external plugin found in command line arguments, + then iterates over each value and discover/import the plugin. + """ + for i, f in enumerate(input_args): + if f in ('--plugin', '--plugins', '--auth-plugin'): + v = input_args[i + 1] + parts = v.split(',') + for part in parts: + Plugins.importer(bytes_(part)) + + @staticmethod + def load( + plugins: List[Union[bytes, type]], + abc_plugins: Optional[List[str]] = None, + ) -> Dict[bytes, List[type]]: + """Accepts a list Python modules, scans them to identify + if they are an implementation of abstract plugin classes and + returns a dictionary of matching plugins for each abstract class. + """ + p: Dict[bytes, List[type]] = {} + for abc_plugin in (abc_plugins or DEFAULT_ABC_PLUGINS): + p[bytes_(abc_plugin)] = [] + for plugin_ in plugins: + klass, module_name = Plugins.importer(plugin_) + assert klass and module_name + mro = list(inspect.getmro(klass)) + # Find the base plugin class that + # this plugin_ is implementing + base_klass = None + for k in mro: + if bytes_(k.__name__) in p: + base_klass = k + break + if base_klass is None: + raise ValueError('%s is NOT a valid plugin' % text_(plugin_)) + if klass not in p[bytes_(base_klass.__name__)]: + p[bytes_(base_klass.__name__)].append(klass) + logger.info('Loaded plugin %s.%s', module_name, klass.__name__) + # print(p) + return p + + @staticmethod + def importer(plugin: Union[bytes, type]) -> Tuple[type, str]: + """Import and returns the plugin.""" + if isinstance(plugin, type): + return (plugin, '__main__') + plugin_ = text_(plugin.strip()) + assert plugin_ != '' + module_name, klass_name = plugin_.rsplit(text_(DOT), 1) + klass = getattr( + importlib.import_module( + module_name.replace( + os.path.sep, text_(DOT), + ), + ), + klass_name, + ) + return (klass, module_name) diff --git a/proxy/common/types.py b/proxy/common/types.py index c41104844..984cc3bdd 100644 --- a/proxy/common/types.py +++ b/proxy/common/types.py @@ -8,18 +8,34 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. 
""" +import re +import ssl +import sys import queue +import socket +import ipaddress +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union -from typing import TYPE_CHECKING, Dict, Any -from typing_extensions import Protocol - -if TYPE_CHECKING: - DictQueueType = queue.Queue[Dict[str, Any]] # pragma: no cover +if TYPE_CHECKING: # pragma: no cover + DictQueueType = queue.Queue[Dict[str, Any]] else: DictQueueType = queue.Queue -class HasFileno(Protocol): - def fileno(self) -> int: - ... # pragma: no cover +Selectable = int +Selectables = List[Selectable] +SelectableEvents = Dict[Selectable, int] # Values are event masks +Readables = Selectables +Writables = Selectables +Descriptors = Tuple[Readables, Writables] +IpAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] +TcpOrTlsSocket = Union[ssl.SSLSocket, socket.socket] +HostPort = Tuple[str, int] + +if sys.version_info.minor == 6: + RePattern = Any +elif sys.version_info.minor in (7, 8): + RePattern = re.Pattern # type: ignore +else: + RePattern = re.Pattern[Any] # type: ignore diff --git a/proxy/common/utils.py b/proxy/common/utils.py index ccd2532ad..4e2eed814 100644 --- a/proxy/common/utils.py +++ b/proxy/common/utils.py @@ -7,16 +7,54 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. spelling:: + + utils """ -import contextlib +import ssl +import sys +import socket +import logging +import argparse import functools import ipaddress -import socket - +import contextlib from types import TracebackType -from typing import Optional, Dict, Any, List, Tuple, Type, Callable +from typing import Any, Dict, List, Type, Tuple, Callable, Optional + +from .types import HostPort +from .constants import ( + CRLF, COLON, HTTP_1_1, IS_WINDOWS, WHITESPACE, DEFAULT_TIMEOUT, + DEFAULT_THREADLESS, PROXY_AGENT_HEADER_VALUE, +) + + +if not IS_WINDOWS: # pragma: no cover + import resource + +logger = logging.getLogger(__name__) + + +def tls_interception_enabled(flags: argparse.Namespace) -> bool: + return flags.ca_key_file is not None and \ + flags.ca_cert_dir is not None and \ + flags.ca_signing_key_file is not None and \ + flags.ca_cert_file is not None + -from .constants import HTTP_1_1, COLON, WHITESPACE, CRLF, DEFAULT_TIMEOUT +def is_threadless(threadless: bool, threaded: bool) -> bool: + # if default is threadless then return true unless + # user has overridden mode using threaded flag. 
+ # + # if default is not threadless then return true + # only if user has overridden using --threadless flag + return (DEFAULT_THREADLESS and not threaded) or (not DEFAULT_THREADLESS and threadless) + + +def is_py2() -> bool: + """Exists only to avoid mocking :data:`sys.version_info` in tests.""" + return sys.version_info.major == 2 def text_(s: Any, encoding: str = 'utf-8', errors: str = 'strict') -> Any: @@ -43,40 +81,60 @@ def bytes_(s: Any, encoding: str = 'utf-8', errors: str = 'strict') -> Any: return s -def build_http_request(method: bytes, url: bytes, - protocol_version: bytes = HTTP_1_1, - headers: Optional[Dict[bytes, bytes]] = None, - body: Optional[bytes] = None) -> bytes: +def build_http_request( + method: bytes, url: bytes, + protocol_version: bytes = HTTP_1_1, + content_type: Optional[bytes] = None, + headers: Optional[Dict[bytes, bytes]] = None, + body: Optional[bytes] = None, + conn_close: bool = False, + no_ua: bool = False, +) -> bytes: """Build and returns a HTTP request packet.""" - if headers is None: - headers = {} + headers = headers or {} + if content_type is not None: + headers[b'Content-Type'] = content_type + has_transfer_encoding = False + has_user_agent = False + for k, _ in headers.items(): + if k.lower() == b'transfer-encoding': + has_transfer_encoding = True + elif k.lower() == b'user-agent': + has_user_agent = True + if body and not has_transfer_encoding: + headers[b'Content-Length'] = bytes_(len(body)) + if not has_user_agent and not no_ua: + headers[b'User-Agent'] = PROXY_AGENT_HEADER_VALUE return build_http_pkt( - [method, url, protocol_version], headers, body) + [method, url, protocol_version], + headers, + body, + conn_close, + ) -def build_http_response(status_code: int, - protocol_version: bytes = HTTP_1_1, - reason: Optional[bytes] = None, - headers: Optional[Dict[bytes, bytes]] = None, - body: Optional[bytes] = None) -> bytes: +def build_http_response( + status_code: int, + protocol_version: bytes = HTTP_1_1, + reason: Optional[bytes] = None, + headers: Optional[Dict[bytes, bytes]] = None, + body: Optional[bytes] = None, + conn_close: bool = False, + no_cl: bool = False, +) -> bytes: """Build and returns a HTTP response packet.""" line = [protocol_version, bytes_(status_code)] if reason: line.append(reason) - if headers is None: - headers = {} - has_content_length = False + headers = headers or {} has_transfer_encoding = False - for k in headers: - if k.lower() == b'content-length': - has_content_length = True + for k, _ in headers.items(): if k.lower() == b'transfer-encoding': has_transfer_encoding = True - if body is not None and \ - not has_transfer_encoding and \ - not has_content_length: - headers[b'Content-Length'] = bytes_(len(body)) - return build_http_pkt(line, headers, body) + break + if not has_transfer_encoding and not no_cl: + headers[b'Content-Length'] = bytes_(len(body)) if body else b'0' + return build_http_pkt(line, headers, body, conn_close) def build_http_header(k: bytes, v: bytes) -> bytes: @@ -84,24 +142,31 @@ def build_http_header(k: bytes, v: bytes) -> bytes: return k + COLON + WHITESPACE + v -def build_http_pkt(line: List[bytes], - headers: Optional[Dict[bytes, bytes]] = None, - body: Optional[bytes] = None) -> bytes: +def build_http_pkt( + line: List[bytes], + headers: Optional[Dict[bytes, bytes]] = None, + body: Optional[bytes] = None, + conn_close: bool = False, +) -> bytes: """Build and returns a HTTP request or response packet.""" - req = WHITESPACE.join(line) + CRLF - if headers is not None: - for k in headers: - req 
+= build_http_header(k, headers[k]) + CRLF - req += CRLF + pkt = WHITESPACE.join(line) + CRLF + headers = headers or {} + if conn_close: + headers[b'Connection'] = b'close' + for k, v in headers.items(): + pkt += build_http_header(k, v) + CRLF + pkt += CRLF if body: - req += body - return req + pkt += body + return pkt def build_websocket_handshake_request( key: bytes, method: bytes = b'GET', - url: bytes = b'/') -> bytes: + url: bytes = b'/', + host: bytes = b'localhost', +) -> bytes: """ Build and returns a Websocket handshake request packet. @@ -112,11 +177,12 @@ def build_websocket_handshake_request( return build_http_request( method, url, headers={ + b'Host': host, b'Connection': b'upgrade', b'Upgrade': b'websocket', b'Sec-WebSocket-Key': key, b'Sec-WebSocket-Version': b'13', - } + }, ) @@ -131,8 +197,8 @@ def build_websocket_handshake_response(accept: bytes) -> bytes: headers={ b'Upgrade': b'websocket', b'Connection': b'Upgrade', - b'Sec-WebSocket-Accept': accept - } + b'Sec-WebSocket-Accept': accept, + }, ) @@ -140,27 +206,53 @@ def find_http_line(raw: bytes) -> Tuple[Optional[bytes], bytes]: """Find and returns first line ending in CRLF along with following buffer. If no ending CRLF is found, line is None.""" - pos = raw.find(CRLF) - if pos == -1: - return None, raw - line = raw[:pos] - rest = raw[pos + len(CRLF):] - return line, rest + parts = raw.split(CRLF, 1) + return (None, raw) \ + if len(parts) == 1 \ + else (parts[0], parts[1]) + + +def wrap_socket( + conn: socket.socket, + keyfile: str, + certfile: str, + cafile: Optional[str] = None, +) -> ssl.SSLSocket: + """Use this to upgrade server_side socket to TLS.""" + ctx = ssl.create_default_context( + ssl.Purpose.CLIENT_AUTH, + cafile=cafile, + ) + ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 + ctx.verify_mode = ssl.CERT_NONE + ctx.load_cert_chain( + certfile=certfile, + keyfile=keyfile, + ) + return ctx.wrap_socket( + conn, + server_side=True, + ) def new_socket_connection( - addr: Tuple[str, int], timeout: int = DEFAULT_TIMEOUT) -> socket.socket: + addr: HostPort, + timeout: float = DEFAULT_TIMEOUT, + source_address: Optional[HostPort] = None, +) -> socket.socket: conn = None try: ip = ipaddress.ip_address(addr[0]) if ip.version == 4: conn = socket.socket( - socket.AF_INET, socket.SOCK_STREAM, 0) + socket.AF_INET, socket.SOCK_STREAM, 0, + ) conn.settimeout(timeout) conn.connect(addr) else: conn = socket.socket( - socket.AF_INET6, socket.SOCK_STREAM, 0) + socket.AF_INET6, socket.SOCK_STREAM, 0, + ) conn.settimeout(timeout) conn.connect((addr[0], addr[1], 0, 0)) except ValueError: @@ -170,14 +262,14 @@ def new_socket_connection( return conn # try to establish dual stack IPv4/IPv6 connection. 
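+    # NOTE: reaching here means addr[0] did not parse as a literal IP
+    # (ValueError above), so defer to create_connection(), which resolves
+    # the hostname and tries each returned address family in turn.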
- return socket.create_connection(addr, timeout=timeout) + return socket.create_connection(addr, timeout=timeout, source_address=source_address) class socket_connection(contextlib.ContextDecorator): """Same as new_socket_connection but as a context manager and decorator.""" - def __init__(self, addr: Tuple[str, int]): - self.addr: Tuple[str, int] = addr + def __init__(self, addr: HostPort): + self.addr: HostPort = addr self.conn: Optional[socket.socket] = None super().__init__() @@ -189,12 +281,14 @@ def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: + exc_tb: Optional[TracebackType], + ) -> None: if self.conn: self.conn.close() - def __call__(self, func: Callable[..., Any] - ) -> Callable[[Tuple[Any, ...], Dict[str, Any]], Any]: + def __call__( # type: ignore + self, func: Callable[..., Any], + ) -> Callable[[Tuple[Any, ...], Dict[str, Any]], Any]: @functools.wraps(func) def decorated(*args: Any, **kwargs: Any) -> Any: with self as conn: @@ -210,3 +304,19 @@ def get_available_port() -> int: return int(port) +def set_open_file_limit(soft_limit: int) -> None: + """Configure open file description soft limit on supported OS.""" + # resource module not available on Windows OS + if IS_WINDOWS: # pragma: no cover + return + + curr_soft_limit, curr_hard_limit = resource.getrlimit( + resource.RLIMIT_NOFILE, + ) + if curr_soft_limit < soft_limit < curr_hard_limit: + resource.setrlimit( + resource.RLIMIT_NOFILE, (soft_limit, curr_hard_limit), + ) + logger.debug( + 'Open file soft limit set to %d', soft_limit, + ) diff --git a/proxy/common/version.py b/proxy/common/version.py index e8af26932..530829c5a 100644 --- a/proxy/common/version.py +++ b/proxy/common/version.py @@ -8,5 +8,7 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ -VERSION = (2, 2, 0) -__version__ = '.'.join(map(str, VERSION[0:3])) +from ._version import VERSION, __version__ # noqa: WPS436 + + +__all__ = '__version__', 'VERSION' diff --git a/proxy/common_neon/solana_tx_error_parser.py b/proxy/common_neon/solana_tx_error_parser.py index ffbddd0bc..52c4b3a13 100644 --- a/proxy/common_neon/solana_tx_error_parser.py +++ b/proxy/common_neon/solana_tx_error_parser.py @@ -56,9 +56,7 @@ class SolTxErrorParser: f'Program log: {EVM_LOADER_ID}' + r':\d+ : Invalid Ethereum transaction nonce: acc (\d+), trx (\d+)' ) - _already_finalized_re = re.compile( - r'Program log: program/src/instruction/transaction_step_from_account.rs:\d+ : Transaction already finalized' - ) + _already_finalized = f'Program {EVM_LOADER_ID} failed: custom program error: 0x4' def __init__(self, receipt: Union[SolTxReceipt, BaseException, str]): assert isinstance(receipt, dict) or isinstance(receipt, BaseException) or isinstance(receipt, str) @@ -227,8 +225,7 @@ def check_if_account_already_exists(self) -> bool: def check_if_already_finalized(self) -> bool: log_list = self.get_log_list() for log in log_list: - m = self._already_finalized_re.search(log) - if m is not None: + if log == self._already_finalized: return True return False diff --git a/proxy/core/__init__.py b/proxy/core/__init__.py index 232621f0b..ae3ea4267 100644 --- a/proxy/core/__init__.py +++ b/proxy/core/__init__.py @@ -7,4 +7,8 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. 
spelling:: + + Subpackages """ diff --git a/proxy/core/acceptor/__init__.py b/proxy/core/acceptor/__init__.py index 9c0a97b33..4fd307763 100644 --- a/proxy/core/acceptor/__init__.py +++ b/proxy/core/acceptor/__init__.py @@ -7,9 +7,14 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. spelling:: + + pre """ -from .acceptor import Acceptor from .pool import AcceptorPool +from .acceptor import Acceptor + __all__ = [ 'Acceptor', diff --git a/proxy/core/acceptor/acceptor.py b/proxy/core/acceptor/acceptor.py index 6a6ff837c..299c0efbd 100644 --- a/proxy/core/acceptor/acceptor.py +++ b/proxy/core/acceptor/acceptor.py @@ -7,134 +7,251 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. spelling:: + + pre """ +import socket import logging -import multiprocessing -import multiprocessing.synchronize +import argparse import selectors -import socket import threading -# import time +import multiprocessing +import multiprocessing.synchronize +from typing import Dict, List, Tuple, Optional from multiprocessing import connection -from multiprocessing.reduction import send_handle, recv_handle -from typing import Optional, Type, Tuple +from multiprocessing.reduction import recv_handle + +from ..work import start_threaded_work, delegate_work_to_pool +from ..event import EventQueue +from ..work.fd import LocalFdExecutor +from ...common.flag import flags +from ...common.types import HostPort +from ...common.logger import Logger +from ...common.backports import NonBlockingQueue +from ...common.constants import DEFAULT_LOCAL_EXECUTOR + + +logger = logging.getLogger(__name__) -from ..connection import TcpClientConnection -from ..threadless import ThreadlessWork, Threadless -from ..event import EventQueue, eventNames -from ...common.flags import Flags -from logged_groups import logged_group + +flags.add_argument( + '--local-executor', + type=int, + default=int(DEFAULT_LOCAL_EXECUTOR), + help='Default: ' + ('1' if DEFAULT_LOCAL_EXECUTOR else '0') + '. ' + + 'Enabled by default. Use 0 to disable. When enabled acceptors ' + + 'will make use of local (same process) executor instead of distributing load across ' + + 'remote (other process) executors. Enable this option to achieve CPU affinity between ' + + 'acceptors and executors, instead of using underlying OS kernel scheduling algorithm.', +) -@logged_group("neon.Acceptor") class Acceptor(multiprocessing.Process): - """Socket client acceptor. + """Work acceptor process. + + On start-up, `Acceptor` accepts a file descriptor which will be used to + accept new work. File descriptor is accepted over a `fd_queue`. + + `Acceptor` goes on to listen for new work over the received server socket. + By default, `Acceptor` will spawn a new thread to handle each work. - Accepts client connection over received server socket handle and - starts a new work thread. + However, when ``--threadless`` option is enabled without ``--local-executor``, + `Acceptor` process will also pre-spawns a + :class:`~proxy.core.acceptor.threadless.Threadless` process during start-up. + Accepted work is delegated to these :class:`~proxy.core.acceptor.threadless.Threadless` + processes. `Acceptor` process shares accepted work with a + :class:`~proxy.core.acceptor.threadless.Threadless` process over it's dedicated pipe. 
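+
+    A rough construction sketch (values are placeholders, mirroring
+    ``__init__`` below)::
+
+        acceptor = Acceptor(
+            idd=0, fd_queue=fd_pipe_end, flags=flags, lock=lock,
+            executor_queues=[], executor_pids=[], executor_locks=[],
+        )
+        acceptor.start()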
""" def __init__( self, idd: int, - work_queue: connection.Connection, - flags: Flags, - work_klass: Type[ThreadlessWork], - lock: multiprocessing.synchronize.Lock, - event_queue: Optional[EventQueue] = None) -> None: + fd_queue: connection.Connection, + flags: argparse.Namespace, + lock: 'multiprocessing.synchronize.Lock', + # semaphore: multiprocessing.synchronize.Semaphore, + executor_queues: List[connection.Connection], + executor_pids: List[int], + executor_locks: List['multiprocessing.synchronize.Lock'], + event_queue: Optional[EventQueue] = None, + ) -> None: super().__init__() - self.idd = idd - self.work_queue: connection.Connection = work_queue self.flags = flags - self.work_klass = work_klass - self.lock = lock + # Eventing core queue self.event_queue = event_queue - + # Index assigned by `AcceptorPool` + self.idd = idd + # Mutex used for synchronization with acceptors + self.lock = lock + # self.semaphore = semaphore + # Queue over which server socket fd is received on start-up + self.fd_queue: connection.Connection = fd_queue + # Available executors + self.executor_queues = executor_queues + self.executor_pids = executor_pids + self.executor_locks = executor_locks + # Selector self.running = multiprocessing.Event() self.selector: Optional[selectors.DefaultSelector] = None - self.sock: Optional[socket.socket] = None - self.threadless_process: Optional[Threadless] = None - self.threadless_client_queue: Optional[connection.Connection] = None - - def start_threadless_process(self) -> None: - pipe = multiprocessing.Pipe() - self.threadless_client_queue = pipe[0] - self.threadless_process = Threadless( - client_queue=pipe[1], - flags=self.flags, - work_klass=self.work_klass, - event_queue=self.event_queue - ) - self.threadless_process.start() - logger.debug('Started process %d', self.threadless_process.pid) - - def shutdown_threadless_process(self) -> None: - assert self.threadless_process and self.threadless_client_queue - logger.debug('Stopped process %d', self.threadless_process.pid) - self.threadless_process.running.set() - self.threadless_process.join() - self.threadless_client_queue.close() - - def start_work(self, conn: socket.socket, addr: Tuple[str, int]) -> None: - if self.flags.threadless and \ - self.threadless_client_queue and \ - self.threadless_process: - self.threadless_client_queue.send(addr) - send_handle( - self.threadless_client_queue, - conn.fileno(), - self.threadless_process.pid - ) - conn.close() - else: - work = self.work_klass( - TcpClientConnection(conn, addr), - flags=self.flags, - event_queue=self.event_queue - ) - work_thread = threading.Thread(target=work.run) - work_thread.daemon = True - work.publish_event( - event_name=eventNames.WORK_STARTED, - event_payload={'fileno': conn.fileno(), 'addr': addr}, - publisher_id=self.__class__.__name__ - ) - work_thread.start() + # File descriptors used to accept new work + self.socks: Dict[int, socket.socket] = {} + # Internals + self._total: Optional[int] = None + self._local_work_queue: Optional['NonBlockingQueue'] = None + self._local: Optional[LocalFdExecutor] = None + self._lthread: Optional[threading.Thread] = None + + def accept( + self, + events: List[Tuple[selectors.SelectorKey, int]], + ) -> List[Tuple[socket.socket, Optional[HostPort]]]: + works = [] + for key, mask in events: + if mask & selectors.EVENT_READ: + try: + conn, addr = self.socks[key.data].accept() + logging.debug( + 'Accepting new work#{0}'.format(conn.fileno()), + ) + works.append((conn, addr or None)) + except BlockingIOError: + # 
logger.info('blocking io error') + pass + return works def run_once(self) -> None: - with self.lock: - assert self.selector and self.sock + if self.selector is not None: events = self.selector.select(timeout=1) if len(events) == 0: return - conn, addr = self.sock.accept() - conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - # now = time.time() - # fileno: int = conn.fileno() - self.start_work(conn, addr) - # logger.info('Work started for fd %d in %f seconds', fileno, time.time() - now) + # locked = False + # try: + # if self.lock.acquire(block=False): + # locked = True + # self.semaphore.release() + # finally: + # if locked: + # self.lock.release() + locked, works = False, [] + try: + # if not self.semaphore.acquire(False, None): + # return + if self.lock.acquire(block=False): + locked = True + works = self.accept(events) + finally: + if locked: + self.lock.release() + for work in works: + if self.flags.threadless and \ + self.flags.local_executor: + assert self._local_work_queue + self._local_work_queue.put(work) + else: + self._work(*work) def run(self) -> None: - self.selector = selectors.DefaultSelector() - fileno = recv_handle(self.work_queue) - self.work_queue.close() - self.sock = socket.fromfd( - fileno, - family=self.flags.family, - type=socket.SOCK_STREAM + Logger.setup( + self.flags.log_file, self.flags.log_level, + self.flags.log_format, ) + self.selector = selectors.DefaultSelector() try: - self.selector.register(self.sock, selectors.EVENT_READ) - if self.flags.threadless: - self.start_threadless_process() + self._recv_and_setup_socks() + if self.flags.threadless and self.flags.local_executor: + self._start_local() + for fileno in self.socks: + self.selector.register( + fileno, selectors.EVENT_READ, fileno, + ) while not self.running.is_set(): self.run_once() except KeyboardInterrupt: pass finally: - self.selector.unregister(self.sock) - if self.flags.threadless: - self.shutdown_threadless_process() - self.sock.close() + for fileno in self.socks: + self.selector.unregister(fileno) + if self.flags.threadless and self.flags.local_executor: + self._stop_local() + for fileno in self.socks: + self.socks[fileno].close() + self.socks.clear() + self.selector.close() logger.debug('Acceptor#%d shutdown', self.idd) + + def _recv_and_setup_socks(self) -> None: + # TODO: Use selector on fd_queue so that we can + # dynamically accept from new fds. + for _ in range(self.fd_queue.recv()): + fileno = recv_handle(self.fd_queue) + # TODO: Convert to socks i.e. list of fds + self.socks[fileno] = socket.fromfd( + fileno, + family=self.flags.family, + type=socket.SOCK_STREAM, + ) + self.fd_queue.close() + + def _start_local(self) -> None: + assert self.socks + self._local_work_queue = NonBlockingQueue() + self._local = LocalFdExecutor( + iid=self.idd, + work_queue=self._local_work_queue, + flags=self.flags, + event_queue=self.event_queue, + ) + self._lthread = threading.Thread(target=self._local.run) + self._lthread.daemon = True + self._lthread.start() + + def _stop_local(self) -> None: + if self._lthread is not None and \ + self._local_work_queue is not None: + self._local_work_queue.put(False) + self._lthread.join() + + def _work(self, conn: socket.socket, addr: Optional[HostPort]) -> None: + self._total = self._total or 0 + if self.flags.threadless: + # Index of worker to which this work should be dispatched + # Use round-robin strategy by default. + # + # By default all acceptors will start sending work to + # 1st workers. To randomize, we offset index by idd. 
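+            # e.g. with num_workers=4 and idd=1: work 0 goes to
+            # worker#1, work 1 to worker#2, ... work 3 back to worker#0.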
+ index = (self._total + self.idd) % self.flags.num_workers + thread = threading.Thread( + target=delegate_work_to_pool, + args=( + self.executor_pids[index], + self.executor_queues[index], + self.executor_locks[index], + conn, + addr, + self.flags.unix_socket_path, + ), + ) + thread.start() + # TODO: Move me into target method + logger.debug( # pragma: no cover + 'Dispatched work#{0}.{1}.{2} to worker#{3}'.format( + conn.fileno(), self.idd, self._total, index, + ), + ) + else: + _, thread = start_threaded_work( + self.flags, + conn, + addr, + event_queue=self.event_queue, + publisher_id=self.__class__.__name__, + ) + # TODO: Move me into target method + logger.debug( # pragma: no cover + 'Started work#{0}.{1}.{2} in thread#{3}'.format( + conn.fileno(), self.idd, self._total, thread.ident, + ), + ) + self._total += 1 diff --git a/proxy/core/acceptor/pool.py b/proxy/core/acceptor/pool.py index 48cedaf96..09fb9f447 100644 --- a/proxy/core/acceptor/pool.py +++ b/proxy/core/acceptor/pool.py @@ -7,127 +7,148 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. spelling:: + + acceptor + acceptors + pre """ import logging +import argparse import multiprocessing -import socket -import threading -# import time +from typing import TYPE_CHECKING, Any, List, Optional from multiprocessing import connection from multiprocessing.reduction import send_handle -from typing import List, Optional, Type from .acceptor import Acceptor -from ..threadless import ThreadlessWork -from ..event import EventQueue, EventDispatcher -from ...common.flags import Flags +from ..listener import ListenerPool +from ...common.flag import flags +from ...common.constants import DEFAULT_NUM_ACCEPTORS + + +if TYPE_CHECKING: # pragma: no cover + from ..event import EventQueue logger = logging.getLogger(__name__) -LOCK = multiprocessing.Lock() + +flags.add_argument( + '--num-acceptors', + type=int, + default=DEFAULT_NUM_ACCEPTORS, + help='Defaults to number of CPU cores.', +) class AcceptorPool: - """AcceptorPool. + """AcceptorPool is a helper class which pre-spawns + :py:class:`~proxy.core.acceptor.acceptor.Acceptor` processes to + utilize all available CPU cores for accepting new work. - Pre-spawns worker processes to utilize all cores available on the system. Server socket connection is - dispatched over a pipe to workers. Each worker accepts incoming client request and spawns a - separate thread to handle the client request. + A file descriptor to consume work from is shared with + :py:class:`~proxy.core.acceptor.acceptor.Acceptor` processes over a + pipe. Each :py:class:`~proxy.core.acceptor.acceptor.Acceptor` + process then concurrently accepts new work over the shared file + descriptor. + + Example usage: + + with AcceptorPool(flags=...) as pool: + while True: + time.sleep(1) + + `flags.work_klass` must implement :py:class:`~proxy.core.work.Work` class. 
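+
+    For instance, the proxy entry point is expected to wire up something
+    like ``flags.work_klass = HttpProtocolHandler`` (illustrative; the
+    exact wiring lives outside this class).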
""" - def __init__(self, flags: Flags, work_klass: Type[ThreadlessWork]) -> None: + def __init__( + self, + flags: argparse.Namespace, + listeners: ListenerPool, + executor_queues: List[connection.Connection], + executor_pids: List[int], + executor_locks: List['multiprocessing.synchronize.Lock'], + event_queue: Optional['EventQueue'] = None, + ) -> None: self.flags = flags - self.socket: Optional[socket.socket] = None + # File descriptor to use for accepting new work + self.listeners: ListenerPool = listeners + # Available executors + self.executor_queues: List[connection.Connection] = executor_queues + self.executor_pids: List[int] = executor_pids + self.executor_locks: List['multiprocessing.synchronize.Lock'] = executor_locks + # Eventing core queue + self.event_queue: Optional['EventQueue'] = event_queue + # Acceptor process instances self.acceptors: List[Acceptor] = [] - self.work_queues: List[connection.Connection] = [] - self.work_klass = work_klass - - self.event_queue: Optional[EventQueue] = None - self.event_dispatcher: Optional[EventDispatcher] = None - self.event_dispatcher_thread: Optional[threading.Thread] = None - self.event_dispatcher_shutdown: Optional[threading.Event] = None - self.manager: Optional[multiprocessing.managers.SyncManager] = None - - if self.flags.enable_events: - self.manager = multiprocessing.Manager() - self.event_queue = EventQueue(self.manager.Queue()) - - def listen(self) -> None: - self.socket = socket.socket(self.flags.family, socket.SOCK_STREAM) - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.bind((str(self.flags.hostname), self.flags.port)) - self.socket.listen(self.flags.backlog) - self.socket.setblocking(False) + # Fd queues used to share file descriptor with acceptor processes + self.fd_queues: List[connection.Connection] = [] + # Internals + self.lock = multiprocessing.Lock() + # self.semaphore = multiprocessing.Semaphore(0) + + def __enter__(self) -> 'AcceptorPool': + self.setup() + return self + + def __exit__(self, *args: Any) -> None: + self.shutdown() + + def setup(self) -> None: + """Setup acceptors.""" + self._start() + execution_mode = ( + 'threadless (local)' + if self.flags.local_executor + else 'threadless (remote)' + ) if self.flags.threadless else 'threaded' logger.info( - 'Listening on %s:%d' % - (self.flags.hostname, self.flags.port)) + 'Started %d acceptors in %s mode' % ( + self.flags.num_acceptors, + execution_mode, + ), + ) + # Send file descriptor to all acceptor processes. 
+ for index in range(self.flags.num_acceptors): + self.fd_queues[index].send(len(self.listeners.pool)) + for listener in self.listeners.pool: + fd = listener.fileno() + assert fd is not None + send_handle( + self.fd_queues[index], + fd, + self.acceptors[index].pid, + ) + self.fd_queues[index].close() + + def shutdown(self) -> None: + logger.info('Shutting down %d acceptors' % self.flags.num_acceptors) + for acceptor in self.acceptors: + acceptor.running.set() + for acceptor in self.acceptors: + acceptor.join() + logger.debug('Acceptors shutdown') - def start_workers(self) -> None: - """Start worker processes.""" - for acceptor_id in range(self.flags.num_workers): + def _start(self) -> None: + """Start acceptor processes.""" + for acceptor_id in range(self.flags.num_acceptors): work_queue = multiprocessing.Pipe() acceptor = Acceptor( idd=acceptor_id, - work_queue=work_queue[1], + fd_queue=work_queue[1], flags=self.flags, - work_klass=self.work_klass, - lock=LOCK, + lock=self.lock, + # semaphore=self.semaphore, event_queue=self.event_queue, + executor_queues=self.executor_queues, + executor_pids=self.executor_pids, + executor_locks=self.executor_locks, ) acceptor.start() logger.debug( 'Started acceptor#%d process %d', acceptor_id, - acceptor.pid) - self.acceptors.append(acceptor) - self.work_queues.append(work_queue[0]) - logger.info('Started %d workers' % self.flags.num_workers) - - def start_event_dispatcher(self) -> None: - self.event_dispatcher_shutdown = threading.Event() - assert self.event_dispatcher_shutdown - assert self.event_queue - self.event_dispatcher = EventDispatcher( - shutdown=self.event_dispatcher_shutdown, - event_queue=self.event_queue - ) - self.event_dispatcher_thread = threading.Thread( - target=self.event_dispatcher.run - ) - self.event_dispatcher_thread.start() - logger.debug('Thread ID: %d', self.event_dispatcher_thread.ident) - - def shutdown(self) -> None: - logger.info('Shutting down %d workers' % self.flags.num_workers) - for acceptor in self.acceptors: - acceptor.running.set() - if self.flags.enable_events: - assert self.event_dispatcher_shutdown - assert self.event_dispatcher_thread - self.event_dispatcher_shutdown.set() - self.event_dispatcher_thread.join() - logger.debug( - 'Shutdown of global event dispatcher thread %d successful', - self.event_dispatcher_thread.ident) - for acceptor in self.acceptors: - acceptor.join() - logger.debug('Acceptors shutdown') - - def setup(self) -> None: - """Listen on port, setup workers and pass server socket to workers.""" - self.listen() - if self.flags.enable_events: - logger.info('Core Event enabled') - self.start_event_dispatcher() - self.start_workers() - - # Send server socket to all acceptor processes. - assert self.socket is not None - for index in range(self.flags.num_workers): - send_handle( - self.work_queues[index], - self.socket.fileno(), - self.acceptors[index].pid + acceptor.pid, ) - self.work_queues[index].close() - self.socket.close() + self.acceptors.append(acceptor) + self.fd_queues.append(work_queue[0]) diff --git a/proxy/core/base/__init__.py b/proxy/core/base/__init__.py new file mode 100644 index 000000000..5ce5a827d --- /dev/null +++ b/proxy/core/base/__init__.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. 
+ :license: BSD, see LICENSE for more details. +""" +from .tcp_server import BaseTcpServerHandler +from .tcp_tunnel import BaseTcpTunnelHandler +from .tcp_upstream import TcpUpstreamConnectionHandler + + +__all__ = [ + 'BaseTcpServerHandler', + 'BaseTcpTunnelHandler', + 'TcpUpstreamConnectionHandler', +] diff --git a/proxy/core/base/tcp_server.py b/proxy/core/base/tcp_server.py new file mode 100644 index 000000000..842e255a7 --- /dev/null +++ b/proxy/core/base/tcp_server.py @@ -0,0 +1,242 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + .. spelling:: + + tcp +""" +import socket +import logging +import selectors +from abc import abstractmethod +from typing import Any, TypeVar, Optional + +from ...core.work import Work +from ...common.flag import flags +from ...common.types import ( + Readables, Writables, TcpOrTlsSocket, SelectableEvents, +) +from ...common.utils import wrap_socket +from ...core.connection import TcpClientConnection +from ...common.constants import ( + DEFAULT_TIMEOUT, DEFAULT_KEY_FILE, DEFAULT_CERT_FILE, + DEFAULT_MAX_SEND_SIZE, DEFAULT_CLIENT_RECVBUF_SIZE, + DEFAULT_SERVER_RECVBUF_SIZE, +) + + +logger = logging.getLogger(__name__) + + +flags.add_argument( + '--key-file', + type=str, + default=DEFAULT_KEY_FILE, + help='Default: None. Server key file to enable end-to-end TLS encryption with clients. ' + 'If used, must also pass --cert-file.', +) + +flags.add_argument( + '--cert-file', + type=str, + default=DEFAULT_CERT_FILE, + help='Default: None. Server certificate to enable end-to-end TLS encryption with clients. ' + 'If used, must also pass --key-file.', +) + +flags.add_argument( + '--client-recvbuf-size', + type=int, + default=DEFAULT_CLIENT_RECVBUF_SIZE, + help='Default: ' + str(int(DEFAULT_CLIENT_RECVBUF_SIZE / 1024)) + + ' KB. Maximum amount of data received from the ' + 'client in a single recv() operation.', +) + +flags.add_argument( + '--server-recvbuf-size', + type=int, + default=DEFAULT_SERVER_RECVBUF_SIZE, + help='Default: ' + str(int(DEFAULT_SERVER_RECVBUF_SIZE / 1024)) + + ' KB. Maximum amount of data received from the ' + 'server in a single recv() operation.', +) + +flags.add_argument( + '--max-sendbuf-size', + type=int, + default=DEFAULT_MAX_SEND_SIZE, + help='Default: ' + str(int(DEFAULT_MAX_SEND_SIZE / 1024)) + + ' KB. Maximum amount of data to flush in a single send() operation.', +) + +flags.add_argument( + '--timeout', + type=int, + default=DEFAULT_TIMEOUT, + help='Default: ' + str(DEFAULT_TIMEOUT) + + '. Number of seconds after which ' + 'an inactive connection must be dropped. Inactivity is defined by no ' + 'data sent or received by the client.', +) + + +T = TypeVar('T', bound=TcpClientConnection) + + +class BaseTcpServerHandler(Work[T]): + """BaseTcpServerHandler implements Work interface. + + BaseTcpServerHandler lifecycle is controlled by Threadless core + using asyncio. If you want to also support threaded mode, also + implement the optional run() method from Work class. + + An instance of BaseTcpServerHandler is created for each client + connection. BaseTcpServerHandler ensures that server is always + ready to accept new data from the client. It also ensures, client + is ready to accept new data before flushing data to it. 
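+
+    For orientation, a minimal echo sketch (hypothetical subclass; the
+    required methods are listed further below)::
+
+        class EchoHandler(BaseTcpServerHandler[TcpClientConnection]):
+
+            @staticmethod
+            def create(*args: Any) -> TcpClientConnection:
+                return TcpClientConnection(*args)
+
+            def handle_data(self, data: memoryview) -> Optional[bool]:
+                self.work.queue(data)   # echo back to the client
+                return None             # keep the connection open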
+ + Most importantly, BaseTcpServerHandler ensures that pending buffers + to the client are flushed before connection is closed. + + Implementations must provide:: + + a. handle_data(data: memoryview) implementation + b. Optionally, also implement other Work method + e.g. initialize, is_inactive, shutdown + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.must_flush_before_shutdown = False + logger.debug( + 'Work#%d accepted from %s', + self.work.connection.fileno(), + self.work.address, + ) + + def initialize(self) -> None: + """Optionally upgrades connection to HTTPS, + sets ``conn`` in non-blocking mode and initializes + HTTP protocol plugins.""" + conn = self._optionally_wrap_socket(self.work.connection) + conn.setblocking(False) + logger.debug('Handling connection %s' % self.work.address) + + @abstractmethod + def handle_data(self, data: memoryview) -> Optional[bool]: + """Optionally return True to close client connection.""" + pass # pragma: no cover + + async def get_events(self) -> SelectableEvents: + events = {} + # We always want to read from client + # Register for EVENT_READ events + if self.must_flush_before_shutdown is False: + events[self.work.connection.fileno()] = selectors.EVENT_READ + # If there is pending buffer for client + # also register for EVENT_WRITE events + if self.work.has_buffer(): + if self.work.connection.fileno() in events: + events[self.work.connection.fileno()] |= selectors.EVENT_WRITE + else: + events[self.work.connection.fileno()] = selectors.EVENT_WRITE + return events + + async def handle_events( + self, + readables: Readables, + writables: Writables, + ) -> bool: + """Return True to shutdown work.""" + teardown = await self.handle_writables( + writables, + ) or await self.handle_readables(readables) + if teardown: + logger.debug( + 'Shutting down client {0} connection'.format( + self.work.address, + ), + ) + return teardown + + async def handle_writables(self, writables: Writables) -> bool: + teardown = False + if self.work.connection.fileno() in writables and self.work.has_buffer(): + logger.debug( + 'Flushing buffer to client {0}'.format(self.work.address), + ) + self.work.flush(self.flags.max_sendbuf_size) + if self.must_flush_before_shutdown is True and \ + not self.work.has_buffer(): + teardown = True + self.must_flush_before_shutdown = False + return teardown + + async def handle_readables(self, readables: Readables) -> bool: + teardown = False + if self.work.connection.fileno() in readables: + try: + data = self.work.recv(self.flags.client_recvbuf_size) + except ConnectionResetError: + logger.info( + 'Connection reset by client {0}'.format( + self.work.address, + ), + ) + return True + except TimeoutError: + logger.info( + 'Client recv timeout error {0}'.format( + self.work.address, + ), + ) + return True + if data is None: + logger.debug( + 'Connection closed by client {0}'.format( + self.work.address, + ), + ) + teardown = True + else: + r = self.handle_data(data) + if isinstance(r, bool) and r is True: + logger.debug( + 'Implementation signaled shutdown for client {0}'.format( + self.work.address, + ), + ) + if self.work.has_buffer(): + logger.debug( + 'Client {0} has pending buffer, will be flushed before shutting down'.format( + self.work.address, + ), + ) + self.must_flush_before_shutdown = True + else: + teardown = True + return teardown + + def _encryption_enabled(self) -> bool: + return self.flags.keyfile is not None and \ + self.flags.certfile is not None + + def 
_optionally_wrap_socket(self, conn: socket.socket) -> TcpOrTlsSocket: + """Attempts to wrap accepted client connection using provided certificates. + + Shutdown and closes client connection upon error. + """ + if self._encryption_enabled(): + assert self.flags.keyfile and self.flags.certfile + # TODO(abhinavsingh): Insecure TLS versions must not be accepted by default + conn = wrap_socket(conn, self.flags.keyfile, self.flags.certfile) + self.work._conn = conn + return conn diff --git a/proxy/core/base/tcp_tunnel.py b/proxy/core/base/tcp_tunnel.py new file mode 100644 index 000000000..7b28fbec7 --- /dev/null +++ b/proxy/core/base/tcp_tunnel.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import logging +import selectors +from abc import abstractmethod +from typing import Any, Optional + +from .tcp_server import BaseTcpServerHandler +from ..connection import TcpClientConnection, TcpServerConnection +from ...http.parser import HttpParser, httpParserTypes +from ...common.types import Readables, Writables, SelectableEvents +from ...common.utils import text_ + + +logger = logging.getLogger(__name__) + + +class BaseTcpTunnelHandler(BaseTcpServerHandler[TcpClientConnection]): + """BaseTcpTunnelHandler build on-top of BaseTcpServerHandler work class. + + On-top of BaseTcpServerHandler implementation, + BaseTcpTunnelHandler introduces an upstream TcpServerConnection + and adds it to the core event loop when needed. + + Currently, implementations must call connect_upstream from within + handle_data. See HttpsConnectTunnelHandler for example usage. 
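+
+    A ``handle_data`` sketch (hypothetical; assumes the parse-then-connect
+    flow described above)::
+
+        def handle_data(self, data: memoryview) -> Optional[bool]:
+            if self.upstream is None:
+                self.request.parse(data)
+                if self.request.is_complete:
+                    self.connect_upstream()
+            else:
+                self.upstream.queue(data)
+            return None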
+ """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.request = HttpParser( + httpParserTypes.REQUEST_PARSER, + enable_proxy_protocol=self.flags.enable_proxy_protocol, + ) + self.upstream: Optional[TcpServerConnection] = None + + @abstractmethod + def handle_data(self, data: memoryview) -> Optional[bool]: + pass # pragma: no cover + + @staticmethod + def create(*args: Any) -> TcpClientConnection: # pragma: no cover + return TcpClientConnection(*args) + + def initialize(self) -> None: + self.work.connection.setblocking(False) + + def shutdown(self) -> None: + if self.upstream: + logger.debug( + 'Connection closed with upstream {0}:{1}'.format( + text_(self.request.host), self.request.port, + ), + ) + self.upstream.close() + super().shutdown() + + async def get_events(self) -> SelectableEvents: + # Get default client events + ev: SelectableEvents = await super().get_events() + # Read from server if we are connected + if self.upstream and self.upstream._conn is not None: + ev[self.upstream.connection.fileno()] = selectors.EVENT_READ + # If there is pending buffer for server + # also register for EVENT_WRITE events + if self.upstream and self.upstream.has_buffer(): + if self.upstream.connection.fileno() in ev: + ev[self.upstream.connection.fileno()] |= selectors.EVENT_WRITE + else: + ev[self.upstream.connection.fileno()] = selectors.EVENT_WRITE + return ev + + async def handle_events( + self, + readables: Readables, + writables: Writables, + ) -> bool: + # Handle client events + do_shutdown: bool = await super().handle_events(readables, writables) + if do_shutdown: + return do_shutdown + # Handle server events + if self.upstream and self.upstream.connection.fileno() in readables: + data = self.upstream.recv(self.flags.server_recvbuf_size) + if data is None: + # Server closed connection + logger.debug('Connection closed by server') + return True + # tunnel data to client + self.work.queue(data) + if self.upstream and self.upstream.connection.fileno() in writables: + self.upstream.flush(self.flags.max_sendbuf_size) + return False + + def connect_upstream(self) -> None: + assert self.request.host and self.request.port + self.upstream = TcpServerConnection( + text_(self.request.host), self.request.port, + ) + self.upstream.connect() + logger.debug( + 'Connection established with upstream {0}:{1}'.format( + text_(self.request.host), self.request.port, + ), + ) diff --git a/proxy/core/base/tcp_upstream.py b/proxy/core/base/tcp_upstream.py new file mode 100644 index 000000000..31a065720 --- /dev/null +++ b/proxy/core/base/tcp_upstream.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import ssl +import logging +from abc import ABC, abstractmethod +from typing import Any, Optional + +from ...common.types import Readables, Writables, Descriptors +from ...core.connection import TcpServerConnection + + +logger = logging.getLogger(__name__) + + +class TcpUpstreamConnectionHandler(ABC): + """:class:`~proxy.core.base.TcpUpstreamConnectionHandler` can + be used to insert an upstream server connection lifecycle. + + Call `initialize_upstream` to initialize the upstream connection object. + Then, directly use ``self.upstream`` object within your class. 
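+
+    For example (sketch; host and port are placeholders)::
+
+        self.initialize_upstream('localhost', 9000)
+        assert self.upstream is not None
+        self.upstream.connect()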
+ + See :class:`~proxy.plugin.proxy_pool.ProxyPoolPlugin` for example usage. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + # This is currently a hack, see comments below for rationale, + # will be fixed later. + super().__init__(*args, **kwargs) # type: ignore + self.upstream: Optional[TcpServerConnection] = None + # TODO: Currently, :class:`~proxy.core.base.TcpUpstreamConnectionHandler` + # is used within :class:`~proxy.http.server.ReverseProxy` and + # :class:`~proxy.plugin.ProxyPoolPlugin`. + # + # For both of which we expect a 4-tuple as arguments + # containing (uuid, flags, client, event_queue). + # We really don't need the rest of the args here. + # May be uuid? May be event_queue in the future. + # But certainly we don't not client here. + # A separate tunnel class must be created which handles + # client connection too. + # + # Both :class:`~proxy.http.server.ReverseProxy` and + # :class:`~proxy.plugin.ProxyPoolPlugin` are currently + # calling client queue within `handle_upstream_data` callback. + # + # This can be abstracted out too. + self.server_recvbuf_size = args[1].server_recvbuf_size + self.total_size = 0 + + @abstractmethod + def handle_upstream_data(self, raw: memoryview) -> None: + raise NotImplementedError() # pragma: no cover + + def initialize_upstream(self, addr: str, port: int) -> None: + self.upstream = TcpServerConnection(addr, port) + + async def get_descriptors(self) -> Descriptors: + if not self.upstream: + return [], [] + return [self.upstream.connection.fileno()], \ + [self.upstream.connection.fileno()] \ + if self.upstream.has_buffer() \ + else [] + + async def read_from_descriptors(self, r: Readables) -> bool: + if self.upstream and \ + self.upstream.connection.fileno() in r: + try: + raw = self.upstream.recv(self.server_recvbuf_size) + if raw is None: # pragma: no cover + # Tear down because upstream proxy closed the connection + return True + self.total_size += len(raw) + self.handle_upstream_data(raw) + except TimeoutError: # pragma: no cover + logger.info('Upstream recv timeout error') + return True + except ssl.SSLWantReadError: # pragma: no cover + logger.info('Upstream SSLWantReadError, will retry') + return False + except ConnectionResetError: # pragma: no cover + logger.debug('Connection reset by upstream') + return True + return False + + async def write_to_descriptors(self, w: Writables) -> bool: + if self.upstream and \ + self.upstream.connection.fileno() in w and \ + self.upstream.has_buffer(): + try: + # TODO: max sendbuf size flag currently not used here + self.upstream.flush() + except ssl.SSLWantWriteError: # pragma: no cover + logger.info('Upstream SSLWantWriteError, will retry') + return False + except BrokenPipeError: # pragma: no cover + logger.debug('BrokenPipeError when flushing to upstream') + return True + return False diff --git a/proxy/core/connection/__init__.py b/proxy/core/connection/__init__.py index ee44bc14a..3457f1dbe 100644 --- a/proxy/core/connection/__init__.py +++ b/proxy/core/connection/__init__.py @@ -7,10 +7,18 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. 
spelling:: + + reusability + Submodules """ -from .connection import TcpConnection, TcpConnectionUninitializedException, tcpConnectionTypes +from .pool import UpstreamConnectionPool +from .types import tcpConnectionTypes from .client import TcpClientConnection from .server import TcpServerConnection +from .connection import TcpConnection, TcpConnectionUninitializedException + __all__ = [ 'TcpConnection', @@ -18,4 +26,5 @@ 'TcpServerConnection', 'TcpClientConnection', 'tcpConnectionTypes', + 'UpstreamConnectionPool', ] diff --git a/proxy/core/connection/client.py b/proxy/core/connection/client.py index 28995a58a..f241c56a0 100644 --- a/proxy/core/connection/client.py +++ b/proxy/core/connection/client.py @@ -8,25 +8,45 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ -import socket import ssl -from typing import Union, Tuple, Optional +from typing import Optional -from .connection import TcpConnection, tcpConnectionTypes, TcpConnectionUninitializedException +from .types import tcpConnectionTypes +from .connection import TcpConnection, TcpConnectionUninitializedException +from ...common.types import HostPort, TcpOrTlsSocket class TcpClientConnection(TcpConnection): - """An accepted client connection request.""" + """A buffered client connection object.""" - def __init__(self, - conn: Union[ssl.SSLSocket, socket.socket], - addr: Tuple[str, int]): + def __init__( + self, + conn: TcpOrTlsSocket, + # optional for unix socket servers + addr: Optional[HostPort] = None, + ) -> None: super().__init__(tcpConnectionTypes.CLIENT) - self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = conn - self.addr: Tuple[str, int] = addr + self._conn: Optional[TcpOrTlsSocket] = conn + self.addr: Optional[HostPort] = addr @property - def connection(self) -> Union[ssl.SSLSocket, socket.socket]: + def address(self) -> str: + return 'unix:client' if not self.addr else '{0}:{1}'.format(self.addr[0], self.addr[1]) + + @property + def connection(self) -> TcpOrTlsSocket: if self._conn is None: raise TcpConnectionUninitializedException() return self._conn + + def wrap(self, keyfile: str, certfile: str) -> None: + self.connection.setblocking(True) + self.flush() + self._conn = ssl.wrap_socket( + self.connection, + server_side=True, + certfile=certfile, + keyfile=keyfile, + ssl_version=ssl.PROTOCOL_TLS, + ) + self.connection.setblocking(False) diff --git a/proxy/core/connection/connection.py b/proxy/core/connection/connection.py index 3aa72eebc..d0bebe26d 100644 --- a/proxy/core/connection/connection.py +++ b/proxy/core/connection/connection.py @@ -8,22 +8,16 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ -import socket -import ssl import logging from abc import ABC, abstractmethod -from typing import NamedTuple, Optional, Union, List +from typing import List, Union, Optional +from .types import tcpConnectionTypes +from ...common.types import TcpOrTlsSocket from ...common.constants import DEFAULT_BUFFER_SIZE, DEFAULT_MAX_SEND_SIZE -logger = logging.getLogger(__name__) - -TcpConnectionTypes = NamedTuple('TcpConnectionTypes', [ - ('SERVER', int), - ('CLIENT', int), -]) -tcpConnectionTypes = TcpConnectionTypes(1, 2) +logger = logging.getLogger(__name__) class TcpConnectionUninitializedException(Exception): @@ -37,32 +31,38 @@ class TcpConnection(ABC): when reading and writing into the socket. 
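+
+    Buffered writes follow a queue/flush cycle (sketch; ``conn`` is any
+    concrete subclass instance)::
+
+        conn.queue(memoryview(b'hello'))
+        while conn.has_buffer():
+            conn.flush()
+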
Implement the connection property abstract method to return - a socket connection object.""" + a socket connection object. + """ - def __init__(self, tag: int): + def __init__(self, tag: int) -> None: + self.tag: str = 'server' if tag == tcpConnectionTypes.SERVER else 'client' self.buffer: List[memoryview] = [] self.closed: bool = False - self.tag: str = 'server' if tag == tcpConnectionTypes.SERVER else 'client' + self._reusable: bool = False + self._num_buffer = 0 @property @abstractmethod - def connection(self) -> Union[ssl.SSLSocket, socket.socket]: + def connection(self) -> TcpOrTlsSocket: """Must return the socket connection to use in this class.""" raise TcpConnectionUninitializedException() # pragma: no cover - def send(self, data: bytes) -> int: + def send(self, data: Union[memoryview, bytes]) -> int: """Users must handle BrokenPipeError exceptions""" + # logger.info(data.tobytes()) return self.connection.send(data) def recv( - self, buffer_size: int = DEFAULT_BUFFER_SIZE) -> Optional[memoryview]: + self, buffer_size: int = DEFAULT_BUFFER_SIZE, + ) -> Optional[memoryview]: """Users must handle socket.error exceptions""" data: bytes = self.connection.recv(buffer_size) if len(data) == 0: return None logger.debug( 'received %d bytes from %s' % - (len(data), self.tag)) + (len(data), self.tag), + ) # logger.info(data) return memoryview(data) @@ -73,20 +73,38 @@ def close(self) -> bool: return self.closed def has_buffer(self) -> bool: - return len(self.buffer) > 0 + return self._num_buffer != 0 def queue(self, mv: memoryview) -> None: self.buffer.append(mv) + self._num_buffer += 1 - def flush(self) -> int: + def flush(self, max_send_size: Optional[int] = None) -> int: """Users must handle BrokenPipeError exceptions""" if not self.has_buffer(): return 0 - mv = self.buffer[0].tobytes() - sent: int = self.send(mv[:DEFAULT_MAX_SEND_SIZE]) + mv = self.buffer[0] + # TODO: Assemble multiple packets if total + # size remains below max send size. + max_send_size = max_send_size or DEFAULT_MAX_SEND_SIZE + sent: int = self.send(mv[:max_send_size]) if sent == len(mv): self.buffer.pop(0) + self._num_buffer -= 1 else: - self.buffer[0] = memoryview(mv[sent:]) + self.buffer[0] = mv[sent:] + del mv logger.debug('flushed %d bytes to %s' % (sent, self.tag)) return sent + + def is_reusable(self) -> bool: + return self._reusable + + def mark_inuse(self) -> None: + self._reusable = False + + def reset(self) -> None: + assert not self.closed + self._reusable = True + self.buffer = [] + self._num_buffer = 0 diff --git a/proxy/core/connection/pool.py b/proxy/core/connection/pool.py new file mode 100644 index 000000000..482dc193f --- /dev/null +++ b/proxy/core/connection/pool.py @@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + .. spelling:: + + reusability +""" +import socket +import logging +import selectors +from typing import TYPE_CHECKING, Any, Set, Dict, Tuple + +from ..work import Work +from .server import TcpServerConnection +from ...common.flag import flags +from ...common.types import HostPort, Readables, Writables, SelectableEvents + + +logger = logging.getLogger(__name__) + + +flags.add_argument( + '--enable-conn-pool', + action='store_true', + default=False, + help='Default: False. 
(WIP) Enable upstream connection pooling.',
+)
+
+
+class UpstreamConnectionPool(Work[TcpServerConnection]):
+    """Manages connection pool to upstream servers.
+
+    `UpstreamConnectionPool` avoids the need to reconnect to upstream
+    servers repeatedly when a reusable connection is available
+    in the pool.
+
+    A separate pool is maintained for each upstream server.
+    So internally, it's a pool of pools.
+
+    Internal data structure maintains references to connection objects
+    that the pool owns or has borrowed. Borrowed connections are marked
+    as NOT reusable.
+
+    For reusable connections only, the pool listens for read events
+    to detect broken connections. This can happen if the pool has opened
+    a connection which was never used and eventually reached the
+    upstream server's timeout limit.
+
+    When a borrowed connection is returned back to the pool,
+    the connection is marked as reusable again. However, if the
+    returned connection has already been closed, it is removed
+    from the internal data structure.
+
+    TODO: Ideally, `UpstreamConnectionPool` must be shared across
+    all cores to make the SSL session cache work without additional
+    out-of-band synchronization.
+
+    TODO: `UpstreamConnectionPool` currently WON'T work for
+    HTTPS connections. This is because of missing support for
+    session cache, session tickets, abbreviated TLS handshakes
+    and other necessary features to make it work.
+
+    NOTE: However, for all HTTP-only upstream connections,
+    `UpstreamConnectionPool` can currently be used to remove slow starts.
+    """
+
+    def __init__(self) -> None:
+        self.connections: Dict[int, TcpServerConnection] = {}
+        self.pools: Dict[HostPort, Set[TcpServerConnection]] = {}
+
+    @staticmethod
+    def create(*args: Any) -> TcpServerConnection:  # pragma: no cover
+        return TcpServerConnection(*args)
+
+    def acquire(self, addr: HostPort) -> Tuple[bool, TcpServerConnection]:
+        """Returns a reusable connection from the pool.
+
+        If none exists, will create and return a new connection."""
+        created, conn = False, None
+        if addr in self.pools:
+            for old_conn in self.pools[addr]:
+                if old_conn.is_reusable():
+                    conn = old_conn
+                    logger.debug(
+                        'Reusing connection#{2} for upstream {0}:{1}'.format(
+                            addr[0], addr[1], id(old_conn),
+                        ),
+                    )
+                    break
+        if conn is None:
+            created, conn = True, self.add(addr)
+        conn.mark_inuse()
+        return created, conn
+
+    def release(self, conn: TcpServerConnection) -> None:
+        """Release a previously acquired connection.
+
+        Releasing a connection will shutdown and close the socket
+        including internal pool cleanup.
+        """
+        assert not conn.is_reusable()
+        logger.debug(
+            'Removing connection#{2} from pool from upstream {0}:{1}'.format(
+                conn.addr[0], conn.addr[1], id(conn),
+            ),
+        )
+        self._remove(conn.connection.fileno())
+
+    def retain(self, conn: TcpServerConnection) -> None:
+        """Retains a previously acquired connection in the pool for reusability."""
+        assert not conn.closed
+        logger.debug(
+            'Retaining connection#{2} to upstream {0}:{1}'.format(
+                conn.addr[0], conn.addr[1], id(conn),
+            ),
+        )
+        # Reset for reusability
+        conn.reset()
+
+    async def get_events(self) -> SelectableEvents:
+        """Returns read event flags for all reusable connections in the pool."""
+        events = {}
+        for connections in self.pools.values():
+            for conn in connections:
+                if conn.is_reusable():
+                    events[conn.connection.fileno()] = selectors.EVENT_READ
+        return events
+
+    async def handle_events(self, readables: Readables, _writables: Writables) -> bool:
+        """Removes reusable connection from the pool.
+ + When pool is the owner of connection, we don't expect a read event from upstream + server. A read event means either upstream closed the connection or connection + has somehow reached an illegal state e.g. upstream sending data for previous + connection acquisition lifecycle.""" + for fileno in readables: + if TYPE_CHECKING: # pragma: no cover + assert isinstance(fileno, int) + logger.debug('Upstream fd#{0} is read ready'.format(fileno)) + self._remove(fileno) + return False + + def add(self, addr: HostPort) -> TcpServerConnection: + """Creates, connects and adds a new connection to the pool. + + Returns newly created connection. + + NOTE: You must not use the returned connection, instead use `acquire`. + """ + new_conn = self.create(addr[0], addr[1]) + new_conn.connect() + self._add(new_conn) + logger.debug( + 'Created new connection#{2} for upstream {0}:{1}'.format( + addr[0], addr[1], id(new_conn), + ), + ) + return new_conn + + def _add(self, conn: TcpServerConnection) -> None: + """Adds a new connection to internal data structure.""" + if conn.addr not in self.pools: + self.pools[conn.addr] = set() + conn._reusable = True + self.pools[conn.addr].add(conn) + self.connections[conn.connection.fileno()] = conn + + def _remove(self, fileno: int) -> None: + """Remove a connection by descriptor from the internal data structure.""" + conn = self.connections[fileno] + logger.debug('Removing conn#{0} from pool'.format(id(conn))) + try: + conn.connection.shutdown(socket.SHUT_WR) + except OSError: + pass + conn.close() + self.pools[conn.addr].remove(conn) + del self.connections[fileno] diff --git a/proxy/core/connection/server.py b/proxy/core/connection/server.py index cbb9806a9..31233049f 100644 --- a/proxy/core/connection/server.py +++ b/proxy/core/connection/server.py @@ -8,29 +8,60 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. 
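The acquire/retain/release lifecycle of the pool, as a sketch; the upstream address is illustrative and the caller is assumed to drive the borrowed socket itself::

    from proxy.core.connection import UpstreamConnectionPool

    pool = UpstreamConnectionPool()
    created, conn = pool.acquire(('example.com', 80))   # hypothetical upstream
    assert created and not conn.is_reusable()           # borrowed, marked in-use
    conn.queue(memoryview(b'GET / HTTP/1.1\r\nHost: example.com\r\n\r\n'))
    while conn.has_buffer():
        conn.flush()
    # retain() resets the connection for reuse; release() would instead
    # shutdown the socket and drop it from the pool.
    pool.retain(conn)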
""" -import socket import ssl -from typing import Optional, Union, Tuple +from typing import Optional -from .connection import TcpConnection, tcpConnectionTypes, TcpConnectionUninitializedException +from .types import tcpConnectionTypes +from .connection import TcpConnection, TcpConnectionUninitializedException +from ...common.types import HostPort, TcpOrTlsSocket from ...common.utils import new_socket_connection class TcpServerConnection(TcpConnection): - """Establishes connection to upstream server.""" + """A buffered server connection object.""" - def __init__(self, host: str, port: int): + def __init__(self, host: str, port: int) -> None: super().__init__(tcpConnectionTypes.SERVER) - self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None - self.addr: Tuple[str, int] = (host, int(port)) + self._conn: Optional[TcpOrTlsSocket] = None + self.addr: HostPort = (host, port) + self.closed = True @property - def connection(self) -> Union[ssl.SSLSocket, socket.socket]: + def connection(self) -> TcpOrTlsSocket: if self._conn is None: raise TcpConnectionUninitializedException() return self._conn - def connect(self) -> None: - if self._conn is not None: - return - self._conn = new_socket_connection(self.addr) + def connect( + self, + addr: Optional[HostPort] = None, + source_address: Optional[HostPort] = None, + ) -> None: + assert self._conn is None + self._conn = new_socket_connection( + addr or self.addr, source_address=source_address, + ) + self.closed = False + + def wrap( + self, + hostname: Optional[str] = None, + ca_file: Optional[str] = None, + as_non_blocking: bool = False, + # Ref https://github.com/PyCQA/pylint/issues/3691 + verify_mode: ssl.VerifyMode = ssl.VerifyMode.CERT_REQUIRED, # pylint: disable=E1101 + ) -> None: + ctx = ssl.create_default_context( + ssl.Purpose.SERVER_AUTH, + cafile=ca_file, + ) + ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 + ctx.check_hostname = hostname is not None + ctx.verify_mode = verify_mode + self.connection.setblocking(True) + self._conn = ctx.wrap_socket( + self.connection, + server_hostname=hostname, + ) + if as_non_blocking: + self.connection.setblocking(False) diff --git a/proxy/core/connection/types.py b/proxy/core/connection/types.py new file mode 100644 index 000000000..44522c81b --- /dev/null +++ b/proxy/core/connection/types.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +from typing import NamedTuple + + +TcpConnectionTypes = NamedTuple( + 'TcpConnectionTypes', [ + ('SERVER', int), + ('CLIENT', int), + ], +) + +tcpConnectionTypes = TcpConnectionTypes(1, 2) diff --git a/proxy/core/event/__init__.py b/proxy/core/event/__init__.py index 6907dcd55..05736f7b5 100644 --- a/proxy/core/event/__init__.py +++ b/proxy/core/event/__init__.py @@ -8,15 +8,18 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. 
""" -from .queue import EventQueue from .names import EventNames, eventNames +from .queue import EventQueue +from .manager import EventManager from .dispatcher import EventDispatcher from .subscriber import EventSubscriber + __all__ = [ 'eventNames', 'EventNames', 'EventQueue', 'EventDispatcher', 'EventSubscriber', + 'EventManager', ] diff --git a/proxy/core/event/dispatcher.py b/proxy/core/event/dispatcher.py index f6bb849e5..c9be44b76 100644 --- a/proxy/core/event/dispatcher.py +++ b/proxy/core/event/dispatcher.py @@ -9,15 +9,14 @@ :license: BSD, see LICENSE for more details. """ import queue -import threading import logging +import threading +from typing import Any, Dict, List +from multiprocessing import connection -from typing import Dict, Any, List - -from ...common.types import DictQueueType - -from .queue import EventQueue from .names import eventNames +from .queue import EventQueue + logger = logging.getLogger(__name__) @@ -25,52 +24,64 @@ class EventDispatcher: """Core EventDispatcher. - Provides: - 1. A dispatcher module which consumes core events and dispatches - them to EventQueueBasePlugin - 2. A publish utility for publishing core events into - global events queue. - Direct consuming from global events queue outside of dispatcher module is not-recommended. Python native multiprocessing queue doesn't provide a fanout functionality which core dispatcher module - implements so that several plugins can consume same published - event at a time. + implements so that several plugins can consume the same published + event concurrently (when necessary). When --enable-events is used, a multiprocessing.Queue is created and - attached to global Flags. This queue can then be used for + attached to global flags. This queue can then be used for dispatching an Event dict object into the queue. When --enable-events is used, dispatcher module is automatically - started. Dispatcher module also ensures that queue is not full and - doesn't utilize too much memory in case there are no event plugins - enabled. + started. Most importantly, dispatcher module ensures that queue is + not flooded and doesn't utilize too much memory in case there are no + event subscribers for published messages. + + EventDispatcher ensures that subscribers will receive the messages + in the order they are published. 
""" def __init__( self, shutdown: threading.Event, - event_queue: EventQueue) -> None: + event_queue: EventQueue, + ) -> None: self.shutdown: threading.Event = shutdown self.event_queue: EventQueue = event_queue - self.subscribers: Dict[str, DictQueueType] = {} + # subscriber connection objects + self.subscribers: Dict[str, connection.Connection] = {} def handle_event(self, ev: Dict[str, Any]) -> None: if ev['event_name'] == eventNames.SUBSCRIBE: - self.subscribers[ev['event_payload']['sub_id']] = \ - ev['event_payload']['channel'] + sub_id = ev['event_payload']['sub_id'] + self.subscribers[sub_id] = ev['event_payload']['conn'] + # send ack + if not self._send( + sub_id, { + 'event_name': eventNames.SUBSCRIBED, + }, + ): + self._close_and_delete(sub_id) elif ev['event_name'] == eventNames.UNSUBSCRIBE: - del self.subscribers[ev['event_payload']['sub_id']] + sub_id = ev['event_payload']['sub_id'] + if sub_id in self.subscribers: + # send ack + logger.debug('unsubscription request ack sent') + self._send( + sub_id, { + 'event_name': eventNames.UNSUBSCRIBED, + }, + ) + self._close_and_delete(sub_id) + else: + logger.info( + 'unsubscription request ack not sent, subscriber already gone', + ) else: # logger.info(ev) - unsub_ids: List[str] = [] - for sub_id in self.subscribers: - try: - self.subscribers[sub_id].put(ev) - except BrokenPipeError: - unsub_ids.append(sub_id) - for sub_id in unsub_ids: - del self.subscribers[sub_id] + self._broadcast(ev) def run_once(self) -> None: ev: Dict[str, Any] = self.event_queue.queue.get(timeout=1) @@ -83,9 +94,46 @@ def run(self) -> None: self.run_once() except queue.Empty: pass - except EOFError: - pass except KeyboardInterrupt: pass except Exception as e: - logger.exception('Event dispatcher exception', exc_info=e) + logger.exception('Dispatcher exception', exc_info=e) + finally: + # Send shutdown message to all active subscribers + self._broadcast({ + 'event_name': eventNames.DISPATCHER_SHUTDOWN, + }) + logger.info('Dispatcher shutdown') + + def _broadcast(self, ev: Dict[str, Any]) -> None: + broken_pipes: List[str] = [] + for sub_id in self.subscribers: + try: + self.subscribers[sub_id].send(ev) + except BrokenPipeError: + logger.warning( + 'Subscriber#%s broken pipe', sub_id, + ) + self._close(sub_id) + broken_pipes.append(sub_id) + for sub_id in broken_pipes: + del self.subscribers[sub_id] + + def _close_and_delete(self, sub_id: str) -> None: + self._close(sub_id) + del self.subscribers[sub_id] + + def _close(self, sub_id: str) -> None: + try: + self.subscribers[sub_id].close() + except Exception: # noqa: S110 + pass + + def _send(self, sub_id: str, payload: Any) -> bool: + done = False + try: + self.subscribers[sub_id].send(payload) + done = True + except (BrokenPipeError, EOFError): + pass + return done diff --git a/proxy/core/event/manager.py b/proxy/core/event/manager.py new file mode 100644 index 000000000..6a100f904 --- /dev/null +++ b/proxy/core/event/manager.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + .. 
spelling:: + + eventing +""" +import logging +import threading +import multiprocessing +from typing import Any, Optional + +from .queue import EventQueue +from .dispatcher import EventDispatcher +from ...common.flag import flags +from ...common.constants import DEFAULT_ENABLE_EVENTS + + +logger = logging.getLogger(__name__) + + +flags.add_argument( + '--enable-events', + action='store_true', + default=DEFAULT_ENABLE_EVENTS, + help='Default: False. Enables core to dispatch lifecycle events. ' + 'Plugins can be used to subscribe for core events.', +) + + +class EventManager: + """Event manager is a context manager which provides + encapsulation around various setup and shutdown steps + to start the eventing core. + """ + + def __init__(self) -> None: + self.queue: Optional[EventQueue] = None + self.dispatcher: Optional[EventDispatcher] = None + self.dispatcher_thread: Optional[threading.Thread] = None + self.dispatcher_shutdown: Optional[threading.Event] = None + + def __enter__(self) -> 'EventManager': + self.setup() + return self + + def __exit__(self, *args: Any) -> None: + self.shutdown() + + def setup(self) -> None: + self.queue = EventQueue(multiprocessing.Queue()) + self.dispatcher_shutdown = threading.Event() + assert self.dispatcher_shutdown + assert self.queue + self.dispatcher = EventDispatcher( + shutdown=self.dispatcher_shutdown, + event_queue=self.queue, + ) + self.dispatcher_thread = threading.Thread( + target=self.dispatcher.run, + ) + self.dispatcher_thread.start() + logger.debug('Dispatcher#%d started', self.dispatcher_thread.ident) + + def shutdown(self) -> None: + assert self.dispatcher_shutdown and self.dispatcher_thread + self.dispatcher_shutdown.set() + self.dispatcher_thread.join() + logger.debug( + 'Dispatcher#%d shutdown', + self.dispatcher_thread.ident, + ) diff --git a/proxy/core/event/names.py b/proxy/core/event/names.py index b45a70b2d..369724aac 100644 --- a/proxy/core/event/names.py +++ b/proxy/core/event/names.py @@ -10,14 +10,24 @@ """ from typing import NamedTuple -EventNames = NamedTuple('EventNames', [ - ('SUBSCRIBE', int), - ('UNSUBSCRIBE', int), - ('WORK_STARTED', int), - ('WORK_FINISHED', int), - ('REQUEST_COMPLETE', int), - ('RESPONSE_HEADERS_COMPLETE', int), - ('RESPONSE_CHUNK_RECEIVED', int), - ('RESPONSE_COMPLETE', int), -]) -eventNames = EventNames(1, 2, 3, 4, 5, 6, 7, 8) + +# Name of the events that eventing framework supports. +# +# Ideally this must be configurable via command line or +# at-least extendable via plugins. +EventNames = NamedTuple( + 'EventNames', [ + ('SUBSCRIBE', int), + ('SUBSCRIBED', int), + ('UNSUBSCRIBE', int), + ('UNSUBSCRIBED', int), + ('DISPATCHER_SHUTDOWN', int), + ('WORK_STARTED', int), + ('WORK_FINISHED', int), + ('REQUEST_COMPLETE', int), + ('RESPONSE_HEADERS_COMPLETE', int), + ('RESPONSE_CHUNK_RECEIVED', int), + ('RESPONSE_COMPLETE', int), + ], +) +eventNames = EventNames(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11) diff --git a/proxy/core/event/queue.py b/proxy/core/event/queue.py index 36b246648..878fe9bf4 100644 --- a/proxy/core/event/queue.py +++ b/proxy/core/event/queue.py @@ -9,33 +9,36 @@ :license: BSD, see LICENSE for more details. """ import os -import threading import time -from typing import Dict, Optional, Any - -from ...common.types import DictQueueType +import threading +from typing import Any, Dict, Optional +from multiprocessing import connection from .names import eventNames +from ...common.types import DictQueueType class EventQueue: - """Global event queue. + """Global event queue. 
+
+    Must be a multiprocess safe queue capable of
+    transporting other queues. This is necessary because currently
+    subscribers use a separate subscription queue to consume events.
+    Subscription queue is exchanged over the global event queue.
+
+    Each published event contains the following schema::
 
-    Each event contains:
+        {
+            'request_id': 'Globally unique request ID',
+            'process_id': 'Process ID of event publisher. This '
+                          'will be the process ID of acceptor workers.',
+            'thread_id': 'Thread ID of event publisher. '
+                         'When --threadless is enabled, this value '
+                         'will be same for all the requests.',
+            'event_timestamp': 'Time when this event occurred',
+            'event_name': 'one of the pre-defined or custom event name',
+            'event_payload': 'Optional data associated with the event',
+            'publisher_id': 'Optional publisher entity unique name',
+        }
 
-    1. Request ID - Globally unique
-    2. Process ID - Process ID of event publisher.
-       This will be process id of acceptor workers.
-    3. Thread ID - Thread ID of event publisher.
-       When --threadless is enabled, this value will
-       be same for all the requests
-       received by a single acceptor worker.
-       When --threadless is disabled, this value will be
-       Thread ID of the thread handling the client request.
-    4. Event Timestamp - Time when this event occur
-    5. Event Name - One of the defined or custom event name
-    6. Event Payload - Optional data associated with the event
-    7. Publisher ID (optional) - Optionally, publishing entity unique name / ID
     """
 
     def __init__(self, queue: DictQueueType) -> None:
@@ -46,13 +49,13 @@ def publish(
             request_id: str,
             event_name: int,
             event_payload: Dict[str, Any],
-            publisher_id: Optional[str] = None
+            publisher_id: Optional[str] = None,
     ) -> None:
         self.queue.put({
-            'request_id': request_id,
             'process_id': os.getpid(),
             'thread_id': threading.get_ident(),
             'event_timestamp': time.time(),
+            'request_id': request_id,
             'event_name': event_name,
             'event_payload': event_payload,
             'publisher_id': publisher_id,
@@ -61,16 +64,22 @@ def publish(
     def subscribe(
             self,
             sub_id: str,
-            channel: DictQueueType) -> None:
-        """Subscribe to global events."""
+            channel: connection.Connection,
+    ) -> None:
+        """Subscribe to global events.
+
+        sub_id is a subscription identifier which must be globally
+        unique. channel MUST be a multiprocessing connection.
+        """
         self.queue.put({
             'event_name': eventNames.SUBSCRIBE,
-            'event_payload': {'sub_id': sub_id, 'channel': channel},
+            'event_payload': {'sub_id': sub_id, 'conn': channel},
         })
 
     def unsubscribe(
             self,
-            sub_id: str) -> None:
+            sub_id: str,
+    ) -> None:
         """Unsubscribe by subscriber id."""
         self.queue.put({
             'event_name': eventNames.UNSUBSCRIBE,
diff --git a/proxy/core/event/subscriber.py b/proxy/core/event/subscriber.py
index ec6afe623..e9ce1b60d 100644
--- a/proxy/core/event/subscriber.py
+++ b/proxy/core/event/subscriber.py
@@ -8,79 +8,171 @@
     :copyright: (c) 2013-present by Abhinav Singh and contributors.
     :license: BSD, see LICENSE for more details.
 """
+import uuid
 import queue
+import logging
 import threading
 import multiprocessing
-import logging
-import uuid
-
-from typing import Dict, Optional, Any, Callable
-
-from ...common.types import DictQueueType
+from typing import Any, Dict, Callable, Optional
+from multiprocessing import connection
 
+from .names import eventNames
 from .queue import EventQueue
 
+
 logger = logging.getLogger(__name__)
 
 
 class EventSubscriber:
-    """Core event subscriber."""
+    """Core event subscriber.
+
+    Usage: Initialize one instance per CPU core for optimum performance.
+
+    EventSubscriber can run within various contexts. E.g. main thread,
+    another thread or a different process. EventSubscriber context
+    can be different from publishers. Publishers can even be processes
+    outside of the proxy.py core.
+
+    `multiprocessing.Pipe` is used to initialize a new Queue for
+    receiving subscribed events from eventing core. Note that
+    core EventDispatcher might be running in a separate process
+    and hence the subscription queue must be multiprocess safe.
+
+    When `subscribe` method is called, EventManager starts
+    a relay thread which consumes events out of the subscription queue
+    and invokes the callback.
 
-    def __init__(self, event_queue: EventQueue) -> None:
-        self.manager = multiprocessing.Manager()
+    NOTE: Callback is executed in the context of relay thread.
+    """
+
+    def __init__(self, event_queue: EventQueue, callback: Callable[[Dict[str, Any]], None]) -> None:
         self.event_queue = event_queue
+        self.callback = callback
         self.relay_thread: Optional[threading.Thread] = None
         self.relay_shutdown: Optional[threading.Event] = None
-        self.relay_channel: Optional[DictQueueType] = None
+        self.relay_recv: Optional[connection.Connection] = None
+        self.relay_send: Optional[connection.Connection] = None
         self.relay_sub_id: Optional[str] = None
 
-    def subscribe(self, callback: Callable[[Dict[str, Any]], None]) -> None:
-        self.relay_shutdown = threading.Event()
-        self.relay_channel = self.manager.Queue()
-        self.relay_thread = threading.Thread(
-            target=self.relay,
-            args=(self.relay_shutdown, self.relay_channel, callback))
-        self.relay_thread.start()
-        self.relay_sub_id = uuid.uuid4().hex
-        self.event_queue.subscribe(self.relay_sub_id, self.relay_channel)
-        logger.debug(
-            'Subscribed relay sub id %s from core events',
-            self.relay_sub_id)
+    def __enter__(self) -> 'EventSubscriber':
+        self.setup()
+        return self
 
-    def unsubscribe(self) -> None:
-        if self.relay_sub_id is None:
-            logger.warning('Unsubscribe called without existing subscription')
-            return
+    def __exit__(self, *args: Any) -> None:
+        self.shutdown()
 
-        assert self.relay_thread
-        assert self.relay_shutdown
-        assert self.relay_channel
-        assert self.relay_sub_id
+    def setup(self, do_subscribe: bool = True) -> None:
+        """Setup subscription thread.
 
-        self.event_queue.unsubscribe(self.relay_sub_id)
-        self.relay_shutdown.set()
-        self.relay_thread.join()
+        Call subscribe() to actually start subscription.
+        """
+        self._start_relay_thread()
+        assert self.relay_sub_id and self.relay_recv
         logger.debug(
-            'Un-subscribed relay sub id %s from core events',
-            self.relay_sub_id)
+            'Subscriber#%s relay setup done',
+            self.relay_sub_id,
+        )
+        if do_subscribe:
+            self.subscribe()
 
-        self.relay_thread = None
-        self.relay_shutdown = None
-        self.relay_channel = None
-        self.relay_sub_id = None
+    def shutdown(self, do_unsubscribe: bool = True) -> None:
+        """Tear down subscription thread.
+
+        Call unsubscribe() to actually stop subscription.
+ """ + self._stop_relay_thread() + logger.debug( + 'Subscriber#%s relay shutdown done', + self.relay_sub_id, + ) + if do_unsubscribe: + self.unsubscribe() + + def subscribe(self) -> None: + assert self.relay_sub_id and self.relay_send + self.event_queue.subscribe(self.relay_sub_id, self.relay_send) + + def unsubscribe(self) -> None: + if self.relay_sub_id is None: + logger.warning( + 'Relay called unsubscribe without an active subscription', + ) + return + try: + self.event_queue.unsubscribe(self.relay_sub_id) + except (BrokenPipeError, EOFError): + pass + finally: + # self.relay_sub_id = None + pass @staticmethod def relay( + sub_id: str, shutdown: threading.Event, - channel: DictQueueType, - callback: Callable[[Dict[str, Any]], None]) -> None: + channel: connection.Connection, + callback: Callable[[Dict[str, Any]], None], + ) -> None: while not shutdown.is_set(): try: - ev = channel.get(timeout=1) - callback(ev) + if channel.poll(timeout=1): + ev = channel.recv() + if ev['event_name'] == eventNames.SUBSCRIBED: + logger.info( + 'Subscriber#{0} subscribe ack received'.format( + sub_id, + ), + ) + elif ev['event_name'] == eventNames.UNSUBSCRIBED: + logger.info( + 'Subscriber#{0} unsubscribe ack received'.format( + sub_id, + ), + ) + break + elif ev['event_name'] == eventNames.DISPATCHER_SHUTDOWN: + logger.info( + 'Subscriber#{0} received dispatcher shutdown event'.format( + sub_id, + ), + ) + break + else: + callback(ev) except queue.Empty: pass except EOFError: break except KeyboardInterrupt: break + logger.debug('bbye!!!') + + def _start_relay_thread(self) -> None: + self.relay_sub_id = uuid.uuid4().hex + self.relay_shutdown = threading.Event() + self.relay_recv, self.relay_send = multiprocessing.Pipe() + self.relay_thread = threading.Thread( + target=EventSubscriber.relay, + args=( + self.relay_sub_id, self.relay_shutdown, + self.relay_recv, self.callback, + ), + ) + self.relay_thread.daemon = True + self.relay_thread.start() + + def _stop_relay_thread(self) -> None: + assert self.relay_thread and self.relay_shutdown and self.relay_recv and self.relay_send + self.relay_shutdown.set() + self.relay_thread.join() + self.relay_recv.close() + # Currently relay_send instance here in + # subscriber is not the same as one received + # by dispatcher. This may cause file + # descriptor leakage. So we make a close + # here explicit on our side of relay_send too. + self.relay_send.close() + self.relay_thread = None + self.relay_shutdown = None + self.relay_recv = None + self.relay_send = None diff --git a/proxy/core/listener/__init__.py b/proxy/core/listener/__init__.py new file mode 100644 index 000000000..d2c89e72a --- /dev/null +++ b/proxy/core/listener/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + .. 
spelling:: + + pre +""" +from .tcp import TcpSocketListener +from .pool import ListenerPool +from .unix import UnixSocketListener + + +__all__ = [ + 'UnixSocketListener', + 'TcpSocketListener', + 'ListenerPool', +] diff --git a/proxy/core/listener/base.py b/proxy/core/listener/base.py new file mode 100644 index 000000000..49357b334 --- /dev/null +++ b/proxy/core/listener/base.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import socket +import logging +import argparse +from abc import ABC, abstractmethod +from typing import Any, Optional + +from ...common.flag import flags +from ...common.constants import DEFAULT_BACKLOG + + +flags.add_argument( + '--backlog', + type=int, + default=DEFAULT_BACKLOG, + help='Default: 100. Maximum number of pending connections to proxy server.', +) + +logger = logging.getLogger(__name__) + + +class BaseListener(ABC): + """Base listener class. + + For usage provide a listen method implementation.""" + + def __init__(self, *args: Any, flags: argparse.Namespace, **kwargs: Any) -> None: + self.flags = flags + self._socket: Optional[socket.socket] = None + + @abstractmethod + def listen(self) -> socket.socket: + raise NotImplementedError() + + def __enter__(self) -> 'BaseListener': + self.setup() + return self + + def __exit__(self, *args: Any) -> None: + self.shutdown() + + def fileno(self) -> Optional[int]: + if not self._socket: + return None + return self._socket.fileno() + + def setup(self) -> None: + self._socket = self.listen() + + def shutdown(self) -> None: + assert self._socket + self._socket.close() diff --git a/proxy/core/listener/pool.py b/proxy/core/listener/pool.py new file mode 100644 index 000000000..b362ae558 --- /dev/null +++ b/proxy/core/listener/pool.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. 
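Because every listener receives a pre-parsed flags namespace, a ListenerPool can also be stood up outside of proxy.py's own flag parser. A sketch with a hand-built namespace; proxy.py normally assembles these attributes from its flag registry::

    import argparse
    import ipaddress

    from proxy.core.listener import ListenerPool

    flags = argparse.Namespace(
        unix_socket_path=None,
        hostname=ipaddress.ip_address('127.0.0.1'),
        port=0,         # ephemeral mode, see --port-file below
        ports=[],       # additional ports, if any
        backlog=100,
    )
    with ListenerPool(flags=flags) as pool:
        print('%d listener(s) ready' % len(pool.pool))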
+""" +import argparse +from typing import TYPE_CHECKING, Any, List, Type + +from .tcp import TcpSocketListener +from .unix import UnixSocketListener + + +if TYPE_CHECKING: # pragma: no cover + from .base import BaseListener + + +class ListenerPool: + """Provides abstraction around starting multiple listeners + based upon flags.""" + + def __init__(self, flags: argparse.Namespace) -> None: + self.flags = flags + self.pool: List['BaseListener'] = [] + + def __enter__(self) -> 'ListenerPool': + self.setup() + return self + + def __exit__(self, *args: Any) -> None: + self.shutdown() + + def setup(self) -> None: + if self.flags.unix_socket_path: + self.add(UnixSocketListener) + else: + self.add(TcpSocketListener) + for port in self.flags.ports: + self.add(TcpSocketListener, port=port) + + def shutdown(self) -> None: + for listener in self.pool: + listener.shutdown() + self.pool.clear() + + def add(self, klass: Type['BaseListener'], **kwargs: Any) -> None: + listener = klass(flags=self.flags, **kwargs) + listener.setup() + self.pool.append(listener) diff --git a/proxy/core/listener/tcp.py b/proxy/core/listener/tcp.py new file mode 100644 index 000000000..f841183fd --- /dev/null +++ b/proxy/core/listener/tcp.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import socket +import logging +from typing import Any, Optional + +from .base import BaseListener +from ...common.flag import flags +from ...common.constants import ( + DEFAULT_PORT, DEFAULT_PORT_FILE, DEFAULT_IPV4_HOSTNAME, +) + + +flags.add_argument( + '--hostname', + type=str, + default=str(DEFAULT_IPV4_HOSTNAME), + help='Default: 127.0.0.1. Server IP address.', +) + +flags.add_argument( + '--port', + type=int, + default=DEFAULT_PORT, + help='Default: 8899. Server port. To listen on more ports, pass them using --ports flag.', +) + +flags.add_argument( + '--ports', + action='append', + nargs='+', + default=None, + help='Default: None. Additional ports to listen on.', +) + +flags.add_argument( + '--port-file', + type=str, + default=DEFAULT_PORT_FILE, + help='Default: None. Save server port numbers. Useful when using --port=0 ephemeral mode.', +) + +logger = logging.getLogger(__name__) + + +class TcpSocketListener(BaseListener): + """Tcp listener.""" + + def __init__(self, *args: Any, port: Optional[int] = None, **kwargs: Any) -> None: + # Port if passed will be used, otherwise + # flag port value will be used. + self.port = port + # Set after binding to a port. + # + # Stored here separately for ephemeral port discovery. 
+ self._port: Optional[int] = None + super().__init__(*args, **kwargs) + + def listen(self) -> socket.socket: + sock = socket.socket( + socket.AF_INET6 if self.flags.hostname.version == 6 else socket.AF_INET, + socket.SOCK_STREAM, + ) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + # s.setsockopt(socket.SOL_TCP, socket.TCP_FASTOPEN, 5) + port = self.port if self.port is not None else self.flags.port + sock.bind((str(self.flags.hostname), port)) + sock.listen(self.flags.backlog) + sock.setblocking(False) + self._port = sock.getsockname()[1] + logger.info( + 'Listening on %s:%s' % + (self.flags.hostname, self._port), + ) + return sock diff --git a/proxy/core/listener/unix.py b/proxy/core/listener/unix.py new file mode 100644 index 000000000..4defd9a18 --- /dev/null +++ b/proxy/core/listener/unix.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import os +import socket +import logging + +from .base import BaseListener +from ...common.flag import flags + + +flags.add_argument( + '--unix-socket-path', + type=str, + default=None, + help='Default: None. Unix socket path to use. ' + + 'When provided --host and --port flags are ignored', +) + +logger = logging.getLogger(__name__) + + +class UnixSocketListener(BaseListener): + """Unix socket domain listener.""" + + def listen(self) -> socket.socket: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(self.flags.unix_socket_path) + sock.listen(self.flags.backlog) + sock.setblocking(False) + logger.info( + 'Listening on %s' % + self.flags.unix_socket_path, + ) + return sock + + def shutdown(self) -> None: + super().shutdown() + if self.flags.unix_socket_path: + os.remove(self.flags.unix_socket_path) diff --git a/proxy/core/ssh/__init__.py b/proxy/core/ssh/__init__.py index 232621f0b..9d9d605de 100644 --- a/proxy/core/ssh/__init__.py +++ b/proxy/core/ssh/__init__.py @@ -7,4 +7,16 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. spelling:: + + Submodules """ +from .handler import SshHttpProtocolHandler +from .listener import SshTunnelListener + + +__all__ = [ + 'SshHttpProtocolHandler', + 'SshTunnelListener', +] diff --git a/proxy/core/ssh/client.py b/proxy/core/ssh/client.py deleted file mode 100644 index 650d89480..000000000 --- a/proxy/core/ssh/client.py +++ /dev/null @@ -1,28 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. - :license: BSD, see LICENSE for more details. -""" -import socket -import ssl -from typing import Union - -from ..connection import TcpClientConnection - - -class SshClient(TcpClientConnection): - """Overrides TcpClientConnection. - - This is necessary because paramiko fileno() can be used for polling - but not for send / recv. 
- """ - - @property - def connection(self) -> Union[ssl.SSLSocket, socket.socket]: - # Dummy return to comply with - return socket.socket() diff --git a/proxy/core/ssh/handler.py b/proxy/core/ssh/handler.py new file mode 100644 index 000000000..ed6ea789f --- /dev/null +++ b/proxy/core/ssh/handler.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import argparse +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: # pragma: no cover + from ...common.types import HostPort + try: + from paramiko.channel import Channel + except ImportError: + pass + + +class SshHttpProtocolHandler: + """Handles incoming connections over forwarded SSH transport.""" + + def __init__(self, flags: argparse.Namespace) -> None: + self.flags = flags + + def on_connection( + self, + chan: 'Channel', + origin: 'HostPort', + server: 'HostPort', + ) -> None: + pass diff --git a/proxy/core/ssh/listener.py b/proxy/core/ssh/listener.py new file mode 100644 index 000000000..d851600fd --- /dev/null +++ b/proxy/core/ssh/listener.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import logging +import argparse +from typing import TYPE_CHECKING, Any, Set, Callable, Optional + + +try: + from paramiko import SSHClient, AutoAddPolicy + from paramiko.transport import Transport + if TYPE_CHECKING: # pragma: no cover + from paramiko.channel import Channel + + from ...common.types import HostPort +except ImportError: # pragma: no cover + pass + +from ...common.flag import flags + + +logger = logging.getLogger(__name__) + + +flags.add_argument( + '--tunnel-hostname', + type=str, + default=None, + help='Default: None. Remote hostname or IP address to which SSH tunnel will be established.', +) + +flags.add_argument( + '--tunnel-port', + type=int, + default=22, + help='Default: 22. SSH port of the remote host.', +) + +flags.add_argument( + '--tunnel-username', + type=str, + default=None, + help='Default: None. Username to use for establishing SSH tunnel.', +) + +flags.add_argument( + '--tunnel-ssh-key', + type=str, + default=None, + help='Default: None. Private key path in pem format', +) + +flags.add_argument( + '--tunnel-ssh-key-passphrase', + type=str, + default=None, + help='Default: None. Private key passphrase', +) + +flags.add_argument( + '--tunnel-remote-port', + type=int, + default=8899, + help='Default: 8899. Remote port which will be forwarded locally for proxy.', +) + + +class SshTunnelListener: + """Connects over SSH and forwards a remote port to local host. 
+ + Incoming connections are delegated to provided callback.""" + + def __init__( + self, + flags: argparse.Namespace, + on_connection_callback: Callable[['Channel', 'HostPort', 'HostPort'], None], + ) -> None: + self.flags = flags + self.on_connection_callback = on_connection_callback + self.ssh: Optional[SSHClient] = None + self.transport: Optional[Transport] = None + self.forwarded: Set['HostPort'] = set() + + def start_port_forward(self, remote_addr: 'HostPort') -> None: + assert self.transport is not None + self.transport.request_port_forward( + *remote_addr, + handler=self.on_connection_callback, + ) + self.forwarded.add(remote_addr) + logger.info('%s:%d forwarding successful...' % remote_addr) + + def stop_port_forward(self, remote_addr: 'HostPort') -> None: + assert self.transport is not None + self.transport.cancel_port_forward(*remote_addr) + self.forwarded.remove(remote_addr) + + def __enter__(self) -> 'SshTunnelListener': + self.setup() + return self + + def __exit__(self, *args: Any) -> None: + self.shutdown() + + def setup(self) -> None: + self.ssh = SSHClient() + self.ssh.load_system_host_keys() + self.ssh.set_missing_host_key_policy(AutoAddPolicy()) + self.ssh.connect( + hostname=self.flags.tunnel_hostname, + port=self.flags.tunnel_port, + username=self.flags.tunnel_username, + key_filename=self.flags.tunnel_ssh_key, + passphrase=self.flags.tunnel_ssh_key_passphrase, + ) + logger.info( + 'SSH connection established to %s:%d...' % ( + self.flags.tunnel_hostname, + self.flags.tunnel_port, + ), + ) + self.transport = self.ssh.get_transport() + + def shutdown(self) -> None: + for remote_addr in list(self.forwarded): + self.stop_port_forward(remote_addr) + self.forwarded.clear() + if self.transport is not None: + self.transport.close() + if self.ssh is not None: + self.ssh.close() diff --git a/proxy/core/ssh/tunnel.py b/proxy/core/ssh/tunnel.py deleted file mode 100644 index e3a61b54d..000000000 --- a/proxy/core/ssh/tunnel.py +++ /dev/null @@ -1,61 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. - :license: BSD, see LICENSE for more details. -""" -from typing import Tuple, Callable - -import paramiko - - -class Tunnel: - """Establishes a tunnel between local (machine where Tunnel is running) and remote host. - Once a tunnel has been established, remote host can route HTTP(s) traffic to - localhost over tunnel. 
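Wiring the handler and listener together, as a sketch; all connection details are illustrative and paramiko is assumed to be installed::

    import argparse

    from proxy.core.ssh import SshHttpProtocolHandler, SshTunnelListener

    flags = argparse.Namespace(
        tunnel_hostname='jump.example.com',     # hypothetical host
        tunnel_port=22,
        tunnel_username='user',                 # hypothetical user
        tunnel_ssh_key='/path/to/key.pem',      # hypothetical key path
        tunnel_ssh_key_passphrase=None,
        tunnel_remote_port=8899,
    )
    handler = SshHttpProtocolHandler(flags=flags)
    with SshTunnelListener(flags, handler.on_connection) as listener:
        listener.start_port_forward(('', flags.tunnel_remote_port))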
- """ - - def __init__( - self, - ssh_username: str, - remote_addr: Tuple[str, int], - private_pem_key: str, - remote_proxy_port: int, - conn_handler: Callable[[paramiko.channel.Channel], None]) -> None: - self.remote_addr = remote_addr - self.ssh_username = ssh_username - self.private_pem_key = private_pem_key - self.remote_proxy_port = remote_proxy_port - self.conn_handler = conn_handler - - def run(self) -> None: - ssh = paramiko.SSHClient() - ssh.load_system_host_keys() - ssh.set_missing_host_key_policy(paramiko.WarningPolicy()) - try: - ssh.connect( - hostname=self.remote_addr[0], - port=self.remote_addr[1], - username=self.ssh_username, - key_filename=self.private_pem_key - ) - print('SSH connection established...') - transport: paramiko.transport.Transport = ssh.get_transport() - transport.request_port_forward('', self.remote_proxy_port) - print('Tunnel port forward setup successful...') - while True: - conn: paramiko.channel.Channel = transport.accept(timeout=1) - e = transport.get_exception() - if e: - raise e - if conn is None: - continue - self.conn_handler(conn) - except KeyboardInterrupt: - pass - finally: - ssh.close() diff --git a/proxy/core/threadless.py b/proxy/core/threadless.py deleted file mode 100644 index 87be7e5b8..000000000 --- a/proxy/core/threadless.py +++ /dev/null @@ -1,247 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. - :license: BSD, see LICENSE for more details. -""" -import os -import socket -import logging -import asyncio -import selectors -import contextlib -import multiprocessing -from multiprocessing import connection -from multiprocessing.reduction import recv_handle - -from abc import ABC, abstractmethod -from typing import Dict, Optional, Tuple, List, Union, Generator, Any, Type -from uuid import uuid4, UUID - -from .connection import TcpClientConnection -from .event import EventQueue, eventNames - -from ..common.flags import Flags -from ..common.types import HasFileno -from ..common.constants import DEFAULT_TIMEOUT - -logger = logging.getLogger(__name__) - - -class ThreadlessWork(ABC): - """Implement ThreadlessWork to hook into the event loop provided by Threadless process.""" - - @abstractmethod - def __init__( - self, - client: TcpClientConnection, - flags: Optional[Flags], - event_queue: Optional[EventQueue] = None, - uid: Optional[UUID] = None) -> None: - self.client = client - self.flags = flags if flags else Flags() - self.event_queue = event_queue - self.uid: UUID = uid if uid is not None else uuid4() - - @abstractmethod - def initialize(self) -> None: - pass # pragma: no cover - - @abstractmethod - def is_inactive(self) -> bool: - return False # pragma: no cover - - @abstractmethod - def get_events(self) -> Dict[socket.socket, int]: - return {} # pragma: no cover - - @abstractmethod - def handle_events( - self, - readables: List[Union[int, HasFileno]], - writables: List[Union[int, HasFileno]]) -> bool: - """Return True to shutdown work.""" - return False # pragma: no cover - - @abstractmethod - def run(self) -> None: - pass - - def publish_event( - self, - event_name: int, - event_payload: Dict[str, Any], - publisher_id: Optional[str] = None) -> None: - if not self.flags.enable_events: - return - assert self.event_queue - self.event_queue.publish( - self.uid.hex, - event_name, - event_payload, - 
publisher_id - ) - - def shutdown(self) -> None: - """Must close any opened resources and call super().shutdown().""" - self.publish_event( - event_name=eventNames.WORK_FINISHED, - event_payload={}, - publisher_id=self.__class__.__name__ - ) - - -class Threadless(multiprocessing.Process): - """Threadless provides an event loop. Use it by implementing Threadless class. - - When --threadless option is enabled, each Acceptor process also - spawns one Threadless process. And instead of spawning new thread - for each accepted client connection, Acceptor process sends - accepted client connection to Threadless process over a pipe. - - HttpProtocolHandler implements ThreadlessWork class and hooks into the - event loop provided by Threadless. - """ - - def __init__( - self, - client_queue: connection.Connection, - flags: Flags, - work_klass: Type[ThreadlessWork], - event_queue: Optional[EventQueue] = None) -> None: - super().__init__() - self.client_queue = client_queue - self.flags = flags - self.work_klass = work_klass - self.event_queue = event_queue - - self.running = multiprocessing.Event() - self.works: Dict[int, ThreadlessWork] = {} - self.selector: Optional[selectors.DefaultSelector] = None - self.loop: Optional[asyncio.AbstractEventLoop] = None - - @contextlib.contextmanager - def selected_events(self) -> Generator[Tuple[List[Union[int, HasFileno]], - List[Union[int, HasFileno]]], - None, None]: - events: Dict[socket.socket, int] = {} - for work in self.works.values(): - events.update(work.get_events()) - assert self.selector is not None - for fd in events: - self.selector.register(fd, events[fd]) - ev = self.selector.select(timeout=1) - readables = [] - writables = [] - for key, mask in ev: - if mask & selectors.EVENT_READ: - readables.append(key.fileobj) - if mask & selectors.EVENT_WRITE: - writables.append(key.fileobj) - yield (readables, writables) - for fd in events.keys(): - self.selector.unregister(fd) - - async def handle_events( - self, fileno: int, - readables: List[Union[int, HasFileno]], - writables: List[Union[int, HasFileno]]) -> bool: - return self.works[fileno].handle_events(readables, writables) - - # TODO: Use correct future typing annotations - async def wait_for_tasks( - self, tasks: Dict[int, Any]) -> None: - for work_id in tasks: - # TODO: Resolving one handle_events here can block resolution of - # other tasks - try: - teardown = await asyncio.wait_for(tasks[work_id], DEFAULT_TIMEOUT) - if teardown: - self.cleanup(work_id) - except asyncio.TimeoutError: - self.cleanup(work_id) - - def fromfd(self, fileno: int) -> socket.socket: - return socket.fromfd( - fileno, family=socket.AF_INET if self.flags.hostname.version == 4 else socket.AF_INET6, - type=socket.SOCK_STREAM) - - def accept_client(self) -> None: - addr = self.client_queue.recv() - fileno = recv_handle(self.client_queue) - self.works[fileno] = self.work_klass( - TcpClientConnection(conn=self.fromfd(fileno), addr=addr), - flags=self.flags, - event_queue=self.event_queue - ) - self.works[fileno].publish_event( - event_name=eventNames.WORK_STARTED, - event_payload={'fileno': fileno, 'addr': addr}, - publisher_id=self.__class__.__name__ - ) - try: - self.works[fileno].initialize() - except Exception as e: - logger.exception( - 'Exception occurred during initialization', - exc_info=e) - self.cleanup(fileno) - - def cleanup_inactive(self) -> None: - inactive_works: List[int] = [] - for work_id in self.works: - if self.works[work_id].is_inactive(): - inactive_works.append(work_id) - for work_id in inactive_works: - 
self.cleanup(work_id) - - def cleanup(self, work_id: int) -> None: - # TODO: HttpProtocolHandler.shutdown can call flush which may block - self.works[work_id].shutdown() - del self.works[work_id] - os.close(work_id) - - def run_once(self) -> None: - assert self.loop is not None - with self.selected_events() as (readables, writables): - if len(readables) == 0 and len(writables) == 0: - # Remove and shutdown inactive connections - self.cleanup_inactive() - return - # Note that selector from now on is idle, - # until all the logic below completes. - # - # Invoke Threadless.handle_events - # TODO: Only send readable / writables that client originally - # registered. - tasks = {} - for fileno in self.works: - tasks[fileno] = self.loop.create_task( - self.handle_events(fileno, readables, writables)) - # Accepted client connection from Acceptor - if self.client_queue in readables: - self.accept_client() - # Wait for Threadless.handle_events to complete - self.loop.run_until_complete(self.wait_for_tasks(tasks)) - # Remove and shutdown inactive connections - self.cleanup_inactive() - - def run(self) -> None: - try: - self.selector = selectors.DefaultSelector() - self.selector.register(self.client_queue, selectors.EVENT_READ) - self.loop = asyncio.get_event_loop() - while not self.running.is_set(): - self.run_once() - except KeyboardInterrupt: - pass - finally: - assert self.selector is not None - self.selector.unregister(self.client_queue) - self.client_queue.close() - assert self.loop is not None - self.loop.close() diff --git a/proxy/plugin/cache/__init__.py b/proxy/core/tls/__init__.py similarity index 70% rename from proxy/plugin/cache/__init__.py rename to proxy/core/tls/__init__.py index f3bfb84b2..1ad8fb2c8 100644 --- a/proxy/plugin/cache/__init__.py +++ b/proxy/core/tls/__init__.py @@ -8,10 +8,12 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ -from .base import BaseCacheResponsesPlugin -from .cache_responses import CacheResponsesPlugin +from .tls import TlsParser +from .types import tlsContentType, tlsHandshakeType + __all__ = [ - 'BaseCacheResponsesPlugin', - 'CacheResponsesPlugin', + 'TlsParser', + 'tlsContentType', + 'tlsHandshakeType', ] diff --git a/proxy/core/tls/certificate.py b/proxy/core/tls/certificate.py new file mode 100644 index 000000000..f71e495c7 --- /dev/null +++ b/proxy/core/tls/certificate.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. 
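The message containers added below are intentionally thin for now: parse() stores the opaque payload and build() replays it verbatim. A quick round-trip sketch::

    from proxy.core.tls.certificate import TlsCertificate

    cert = TlsCertificate()
    ok, _ = cert.parse(b'\x00\x01\x02')     # opaque payload, stored as-is
    assert ok and cert.build() == b'\x00\x01\x02'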
+""" +from typing import Tuple, Optional + + +class TlsCertificate: + """TLS Certificate""" + + def __init__(self) -> None: + self.data: Optional[bytes] = None + + def parse(self, raw: bytes) -> Tuple[bool, bytes]: + self.data = raw + return True, raw + + def build(self) -> bytes: + assert self.data + return self.data + + +class TlsCertificateRequest: + """TLS Certificate Request""" + + def __init__(self) -> None: + self.data: Optional[bytes] = None + + def parse(self, raw: bytes) -> Tuple[bool, bytes]: + return False, raw + + def build(self) -> bytes: + assert self.data + return self.data + + +class TlsCertificateVerify: + """TLS Certificate Verify""" + + def __init__(self) -> None: + self.data: Optional[bytes] = None + + def parse(self, raw: bytes) -> Tuple[bool, bytes]: + return False, raw + + def build(self) -> bytes: + assert self.data + return self.data diff --git a/proxy/core/tls/finished.py b/proxy/core/tls/finished.py new file mode 100644 index 000000000..df9db0625 --- /dev/null +++ b/proxy/core/tls/finished.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +from typing import Tuple, Optional + + +class TlsFinished: + """TLS Finished""" + + def __init__(self) -> None: + self.data: Optional[bytes] = None + + def parse(self, raw: bytes) -> Tuple[bool, bytes]: + return False, raw + + def build(self) -> bytes: + assert self.data + return self.data diff --git a/proxy/core/tls/handshake.py b/proxy/core/tls/handshake.py new file mode 100644 index 000000000..7a03e2471 --- /dev/null +++ b/proxy/core/tls/handshake.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. 
+""" +import struct +import logging +from typing import Tuple, Optional + +from .hello import ( + TlsClientHello, TlsServerHello, TlsHelloRequest, TlsServerHelloDone, +) +from .types import tlsHandshakeType +from .finished import TlsFinished +from .certificate import ( + TlsCertificate, TlsCertificateVerify, TlsCertificateRequest, +) +from .key_exchange import TlsClientKeyExchange, TlsServerKeyExchange + + +logger = logging.getLogger(__name__) + + +class TlsHandshake: + """TLS Handshake""" + + def __init__(self) -> None: + self.msg_type: int = tlsHandshakeType.OTHER + self.length: Optional[bytes] = None + self.hello_request: Optional[TlsHelloRequest] = None + self.client_hello: Optional[TlsClientHello] = None + self.server_hello: Optional[TlsServerHello] = None + self.certificate: Optional[TlsCertificate] = None + self.server_key_exchange: Optional[TlsServerKeyExchange] = None + self.certificate_request: Optional[TlsCertificateRequest] = None + self.server_hello_done: Optional[TlsServerHelloDone] = None + self.certificate_verify: Optional[TlsCertificateVerify] = None + self.client_key_exchange: Optional[TlsClientKeyExchange] = None + self.finished: Optional[TlsFinished] = None + self.data: Optional[bytes] = None + + def parse(self, raw: bytes) -> Tuple[bool, bytes]: + length = len(raw) + if length < 4: + logger.debug('invalid data, len(raw) = %s', length) + return False, raw + payload_length, = struct.unpack('!I', b'\x00' + raw[1:4]) + self.length = payload_length + if length < 4 + payload_length: + logger.debug( + 'incomplete data, len(raw) = %s, len(payload) = %s', length, payload_length, + ) + return False, raw + # parse + self.msg_type = raw[0] + self.length = raw[1:4] + self.data = raw[: 4 + payload_length] + payload = raw[4: 4 + payload_length] + if self.msg_type == tlsHandshakeType.HELLO_REQUEST: + # parse hello request + self.hello_request = TlsHelloRequest() + self.hello_request.parse(payload) + elif self.msg_type == tlsHandshakeType.CLIENT_HELLO: + # parse client hello + self.client_hello = TlsClientHello() + self.client_hello.parse(payload) + elif self.msg_type == tlsHandshakeType.SERVER_HELLO: + # parse server hello + self.server_hello = TlsServerHello() + self.server_hello.parse(payload) + elif self.msg_type == tlsHandshakeType.CERTIFICATE: + # parse certificate + self.certificate = TlsCertificate() + self.certificate.parse(payload) + elif self.msg_type == tlsHandshakeType.SERVER_KEY_EXCHANGE: + # parse server key exchange + self.server_key_exchange = TlsServerKeyExchange() + self.server_key_exchange.parse(payload) + elif self.msg_type == tlsHandshakeType.CERTIFICATE_REQUEST: + # parse certificate request + self.certificate_request = TlsCertificateRequest() + self.certificate_request.parse(payload) + elif self.msg_type == tlsHandshakeType.SERVER_HELLO_DONE: + # parse server hello done + self.server_hello_done = TlsServerHelloDone() + self.server_hello_done.parse(payload) + elif self.msg_type == tlsHandshakeType.CERTIFICATE_VERIFY: + # parse certificate verify + self.certificate_verify = TlsCertificateVerify() + self.certificate_verify.parse(payload) + elif self.msg_type == tlsHandshakeType.CLIENT_KEY_EXCHANGE: + # parse client key exchange + self.client_key_exchange = TlsClientKeyExchange() + self.client_key_exchange.parse(payload) + elif self.msg_type == tlsHandshakeType.FINISHED: + # parse finished + self.finished = TlsFinished() + self.finished.parse(payload) + return True, raw[4 + payload_length:] + + def build(self) -> bytes: + data = b'' + data += bytes([self.msg_type]) 
+ payload = b'' + if self.msg_type == tlsHandshakeType.CLIENT_HELLO: + assert self.client_hello + payload = self.client_hello.build() + elif self.msg_type == tlsHandshakeType.SERVER_HELLO: + assert self.server_hello + payload = self.server_hello.build() + elif self.msg_type == tlsHandshakeType.CERTIFICATE: + assert self.certificate + payload = self.certificate.build() + elif self.msg_type == tlsHandshakeType.SERVER_KEY_EXCHANGE: + assert self.server_key_exchange + payload = self.server_key_exchange.build() + # calculate length + length = struct.pack('!I', len(payload))[1:] + data += length + data += payload + return data diff --git a/proxy/core/tls/hello.py b/proxy/core/tls/hello.py new file mode 100644 index 000000000..f60c724b2 --- /dev/null +++ b/proxy/core/tls/hello.py @@ -0,0 +1,242 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import os +import struct +import logging +from typing import Tuple, Optional + +from .pretty import pretty_hexlify + + +logger = logging.getLogger(__name__) + + +class TlsHelloRequest: + """TLS Hello Request""" + + def __init__(self) -> None: + self.data: Optional[bytes] = None + + def parse(self, raw: bytes) -> None: + self.data = raw + + def build(self) -> bytes: + assert self.data + return self.data + + +class TlsClientHello: + """TLS Client Hello""" + + def __init__(self) -> None: + self.protocol_version: Optional[bytes] = None + self.random: Optional[bytes] = None + self.session_id: Optional[bytes] = None + self.cipher_suite: Optional[bytes] = None + self.compression_method: Optional[bytes] = None + self.extension: Optional[bytes] = None + + def parse(self, raw: bytes) -> Tuple[bool, bytes]: + try: + idx = 0 + length = len(raw) + self.protocol_version = raw[idx:idx + 2] + idx += 2 + self.random = raw[idx:idx + 32] + idx += 32 + session_length = raw[idx] + self.session_id = raw[idx: idx + 1 + session_length] + idx += 1 + session_length + cipher_suite_length, = struct.unpack('!H', raw[idx: idx + 2]) + self.cipher_suite = raw[idx: idx + 2 + cipher_suite_length] + idx += 2 + cipher_suite_length + compression_method_length = raw[idx] + self.compression_method = raw[ + idx: idx + + 1 + compression_method_length + ] + idx += 1 + compression_method_length + # extension + if idx == length: + self.extension = b'' + else: + extension_length, = struct.unpack('!H', raw[idx: idx + 2]) + self.extension = raw[idx: idx + 2 + extension_length] + idx += 2 + extension_length + return True, raw[idx:] + except Exception as e: + logger.exception(e) + return False, raw + + def build(self) -> bytes: + # calculate length + return b''.join([ + bs for bs in ( + self.protocol_version, self.random, self.session_id, self.cipher_suite, + self.compression_method, self.extension, + ) if bs is not None + ]) + + def format(self) -> str: + parts = [] + parts.append( + 'Protocol Version: %s' % ( + pretty_hexlify(self.protocol_version) + if self.protocol_version is not None + else '' + ), + ) + parts.append( + 'Random: %s' % ( + pretty_hexlify(self.random) + if self.random is not None else '' + ), + ) + parts.append( + 'Session ID: %s' % ( + pretty_hexlify(self.session_id) + if self.session_id is not None + else '' + ), + ) + parts.append( + 'Cipher Suite: %s' % ( + pretty_hexlify(self.cipher_suite) + if 
self.cipher_suite is not None + else '' + ), + ) + parts.append( + 'Compression Method: %s' % ( + pretty_hexlify(self.compression_method) + if self.compression_method is not None + else '' + ), + ) + parts.append( + 'Extension: %s' % ( + pretty_hexlify(self.extension) + if self.extension is not None + else '' + ), + ) + return os.linesep.join(parts) + + +class TlsServerHello: + """TLS Server Hello""" + + def __init__(self) -> None: + self.protocol_version: Optional[bytes] = None + self.random: Optional[bytes] = None + self.session_id: Optional[bytes] = None + self.cipher_suite: Optional[bytes] = None + self.compression_method: Optional[bytes] = None + self.extension: Optional[bytes] = None + + def parse(self, raw: bytes) -> Tuple[bool, bytes]: + try: + idx = 0 + length = len(raw) + self.protocol_version = raw[idx:idx + 2] + idx += 2 + self.random = raw[idx:idx + 32] + idx += 32 + session_length = raw[idx] + self.session_id = raw[idx: idx + 1 + session_length] + idx += 1 + session_length + self.cipher_suite = raw[idx: idx + 2] + idx += 2 + compression_method_length = raw[idx] + self.compression_method = raw[ + idx: idx + + 1 + compression_method_length + ] + idx += 1 + compression_method_length + # extension + if idx == length: + self.extension = b'' + else: + extension_length, = struct.unpack('!H', raw[idx: idx + 2]) + self.extension = raw[idx: idx + 2 + extension_length] + idx += 2 + extension_length + return True, raw[idx:] + except Exception as e: + logger.exception(e) + return False, raw + + def build(self) -> bytes: + return b''.join([ + bs for bs in ( + self.protocol_version, self.random, self.session_id, self.cipher_suite, + self.compression_method, self.extension, + ) if bs is not None + ]) + + def format(self) -> str: + parts = [] + parts.append( + 'Protocol Version: %s' % ( + pretty_hexlify(self.protocol_version) + if self.protocol_version is not None + else '' + ), + ) + parts.append( + 'Random: %s' % ( + pretty_hexlify(self.random) + if self.random is not None + else '' + ), + ) + parts.append( + 'Session ID: %s' % ( + pretty_hexlify(self.session_id) + if self.session_id is not None + else '' + ), + ) + parts.append( + 'Cipher Suite: %s' % ( + pretty_hexlify(self.cipher_suite) + if self.cipher_suite is not None + else '' + ), + ) + parts.append( + 'Compression Method: %s' % ( + pretty_hexlify(self.compression_method) + if self.compression_method is not None + else '' + ), + ) + parts.append( + 'Extension: %s' % ( + pretty_hexlify(self.extension) + if self.extension is not None + else '' + ), + ) + return os.linesep.join(parts) + + +class TlsServerHelloDone: + """TLS Server Hello Done""" + + def __init__(self) -> None: + self.data: Optional[bytes] = None + + def parse(self, raw: bytes) -> Tuple[bool, bytes]: + return False, raw + + def build(self) -> bytes: + assert self.data + return self.data diff --git a/proxy/core/tls/key_exchange.py b/proxy/core/tls/key_exchange.py new file mode 100644 index 000000000..cb0059ed4 --- /dev/null +++ b/proxy/core/tls/key_exchange.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. 
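+
+    Key exchange payloads are kept opaque here. Illustrative round-trip
+    (assuming ``raw`` is a complete ServerKeyExchange body)::
+
+        ske = TlsServerKeyExchange()
+        ok, _ = ske.parse(raw)
+        assert ok and ske.build() == raw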
+""" +from typing import Tuple, Optional + + +class TlsServerKeyExchange: + """TLS Server Key Exchange""" + + def __init__(self) -> None: + self.data: Optional[bytes] = None + + def parse(self, raw: bytes) -> Tuple[bool, bytes]: + self.data = raw + return True, raw + + def build(self) -> bytes: + assert self.data + return self.data + + +class TlsClientKeyExchange: + """TLS Client Key Exchange""" + + def __init__(self) -> None: + self.data: Optional[bytes] = None + + def parse(self, raw: bytes) -> Tuple[bool, bytes]: + return False, raw + + def build(self) -> bytes: + assert self.data + return self.data diff --git a/proxy/core/tls/pretty.py b/proxy/core/tls/pretty.py new file mode 100644 index 000000000..200463e3d --- /dev/null +++ b/proxy/core/tls/pretty.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import binascii + + +def pretty_hexlify(raw: bytes) -> str: + hexlified = binascii.hexlify(raw).decode('utf-8') + return ' '.join([hexlified[i: i+2] for i in range(0, len(hexlified), 2)]) diff --git a/proxy/core/tls/tls.py b/proxy/core/tls/tls.py new file mode 100644 index 000000000..9e5aa89eb --- /dev/null +++ b/proxy/core/tls/tls.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import struct +import logging +from typing import Tuple, Optional + +from .types import tlsContentType +from .handshake import TlsHandshake +from .certificate import TlsCertificate + + +logger = logging.getLogger(__name__) + + +class TlsParser: + """TLS packet parser""" + + def __init__(self) -> None: + self.content_type: int = tlsContentType.OTHER + self.protocol_version: Optional[bytes] = None + self.length: Optional[bytes] = None + # only parse hand shake payload temporary + self.handshake: Optional[TlsHandshake] = None + self.certificate: Optional[TlsCertificate] + + def parse(self, raw: bytes) -> Tuple[bool, bytes]: + """Parse TLS fragmentation. 
+ + References + + https://datatracker.ietf.org/doc/html/rfc5246#page-15 + https://datatracker.ietf.org/doc/html/rfc5077#page-3 + https://datatracker.ietf.org/doc/html/rfc8446#page-10 + """ + length = len(raw) + if length < 5: + logger.debug('invalid data, len(raw) = %s', length) + return False, raw + payload_length, = struct.unpack('!H', raw[3:5]) + if length < 5 + payload_length: + logger.debug( + 'incomplete data, len(raw) = %s, len(payload) = %s', length, payload_length, + ) + return False, raw + # parse + self.content_type = raw[0] + self.protocol_version = raw[1:3] + self.length = raw[3:5] + payload = raw[5:5 + payload_length] + if self.content_type == tlsContentType.HANDSHAKE: + # parse handshake + self.handshake = TlsHandshake() + self.handshake.parse(payload) + return True, raw[5 + payload_length:] + + def build(self) -> bytes: + data = b'' + data += bytes([self.content_type]) + assert self.protocol_version + data += self.protocol_version + payload = b'' + if self.content_type == tlsContentType.HANDSHAKE: + assert self.handshake + payload += self.handshake.build() + length = struct.pack('!H', len(payload)) + data += length + data += payload + return data diff --git a/proxy/core/tls/types.py b/proxy/core/tls/types.py new file mode 100644 index 000000000..640cffe28 --- /dev/null +++ b/proxy/core/tls/types.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +from typing import NamedTuple + + +TlsContentType = NamedTuple( + 'TlsContentType', [ + ('CHANGE_CIPHER_SPEC', int), + ('ALERT', int), + ('HANDSHAKE', int), + ('APPLICATION_DATA', int), + ('OTHER', int), + ], +) +tlsContentType = TlsContentType(20, 21, 22, 23, 255) + + +TlsHandshakeType = NamedTuple( + 'TlsHandshakeType', [ + ('HELLO_REQUEST', int), + ('CLIENT_HELLO', int), + ('SERVER_HELLO', int), + ('CERTIFICATE', int), + ('SERVER_KEY_EXCHANGE', int), + ('CERTIFICATE_REQUEST', int), + ('SERVER_HELLO_DONE', int), + ('CERTIFICATE_VERIFY', int), + ('CLIENT_KEY_EXCHANGE', int), + ('FINISHED', int), + ('OTHER', int), + ], +) +tlsHandshakeType = TlsHandshakeType(0, 1, 2, 11, 12, 13, 14, 15, 16, 20, 255) diff --git a/proxy/core/work/__init__.py b/proxy/core/work/__init__.py new file mode 100644 index 000000000..dee3296af --- /dev/null +++ b/proxy/core/work/__init__.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + .. 
spelling:: + + pre +""" +from .pool import ThreadlessPool +from .work import Work +from .local import BaseLocalExecutor +from .remote import BaseRemoteExecutor +from .delegate import delegate_work_to_pool +from .threaded import start_threaded_work +from .threadless import Threadless + + +__all__ = [ + 'Work', + 'Threadless', + 'ThreadlessPool', + 'delegate_work_to_pool', + 'start_threaded_work', + 'BaseLocalExecutor', + 'BaseRemoteExecutor', +] diff --git a/proxy/core/work/delegate.py b/proxy/core/work/delegate.py new file mode 100644 index 000000000..76f7e71eb --- /dev/null +++ b/proxy/core/work/delegate.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +from typing import TYPE_CHECKING, Optional +from multiprocessing.reduction import send_handle + + +if TYPE_CHECKING: # pragma: no cover + import socket + import multiprocessing + from multiprocessing import connection + + from ...common.types import HostPort + + +def delegate_work_to_pool( + worker_pid: int, + work_queue: 'connection.Connection', + work_lock: 'multiprocessing.synchronize.Lock', + conn: 'socket.socket', + addr: Optional['HostPort'], + unix_socket_path: Optional[str] = None, +) -> None: + """Utility method to delegate a work to threadless executor pool.""" + with work_lock: + # Accepted client address is empty string for + # unix socket domain, avoid sending empty string + # for optimization. + if not unix_socket_path: + work_queue.send(addr) + send_handle( + work_queue, + conn.fileno(), + worker_pid, + ) + conn.close() diff --git a/proxy/core/work/fd/__init__.py b/proxy/core/work/fd/__init__.py new file mode 100644 index 000000000..f277bd5a8 --- /dev/null +++ b/proxy/core/work/fd/__init__.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +from .fd import ThreadlessFdExecutor +from .local import LocalFdExecutor +from .remote import RemoteFdExecutor + + +__all__ = [ + 'ThreadlessFdExecutor', + 'LocalFdExecutor', + 'RemoteFdExecutor', +] diff --git a/proxy/core/work/fd/fd.py b/proxy/core/work/fd/fd.py new file mode 100644 index 000000000..cb6e903d7 --- /dev/null +++ b/proxy/core/work/fd/fd.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. 
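+
+    Illustrative entry point (assuming an already accepted connection
+    ``conn`` and its peer address ``addr``)::
+
+        executor.work(conn.fileno(), addr, conn)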
+""" +import socket +import logging +from typing import Any, TypeVar, Optional + +from ...event import eventNames +from ..threadless import Threadless +from ....common.types import HostPort, TcpOrTlsSocket + + +T = TypeVar('T') + +logger = logging.getLogger(__name__) + + +class ThreadlessFdExecutor(Threadless[T]): + """A threadless executor which handles file descriptors + and works with read/write events over a socket.""" + + def work(self, *args: Any) -> None: + fileno: int = args[0] + addr: Optional[HostPort] = args[1] + conn: Optional[TcpOrTlsSocket] = args[2] + conn = conn or socket.fromfd( + fileno, family=socket.AF_INET if self.flags.hostname.version == 4 else socket.AF_INET6, + type=socket.SOCK_STREAM, + ) + uid = '%s-%s-%s' % (self.iid, self._total, fileno) + self.works[fileno] = self.create(uid, conn, addr) + self.works[fileno].publish_event( + event_name=eventNames.WORK_STARTED, + event_payload={'fileno': fileno, 'addr': addr}, + publisher_id=self.__class__.__name__, + ) + try: + self.works[fileno].initialize() + self._total += 1 + except Exception as e: + logger.exception( # pragma: no cover + 'Exception occurred during initialization', + exc_info=e, + ) + self._cleanup(fileno) diff --git a/proxy/core/work/fd/local.py b/proxy/core/work/fd/local.py new file mode 100644 index 000000000..65fae31c6 --- /dev/null +++ b/proxy/core/work/fd/local.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import queue +import asyncio +import contextlib +from typing import Any, Optional + +from .fd import ThreadlessFdExecutor +from ....common.backports import NonBlockingQueue + + +class LocalFdExecutor(ThreadlessFdExecutor[NonBlockingQueue]): + """A threadless executor implementation which uses a queue to receive new work.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._loop: Optional[asyncio.AbstractEventLoop] = None + + @property + def loop(self) -> Optional[asyncio.AbstractEventLoop]: + if self._loop is None: + self._loop = asyncio.get_event_loop_policy().new_event_loop() + return self._loop + + def work_queue_fileno(self) -> Optional[int]: + return None + + def receive_from_work_queue(self) -> bool: + with contextlib.suppress(queue.Empty): + work = self.work_queue.get() + if isinstance(work, bool) and work is False: + return True + self.initialize(work) + return False + + def initialize(self, work: Any) -> None: + assert isinstance(work, tuple) + conn, addr = work + # NOTE: Here we are assuming to receive a connection object + # and not a fileno because we are a LocalExecutor. + fileno = conn.fileno() + self.work(fileno, addr, conn) diff --git a/proxy/core/work/fd/remote.py b/proxy/core/work/fd/remote.py new file mode 100644 index 000000000..fdad0ab32 --- /dev/null +++ b/proxy/core/work/fd/remote.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. 
+""" +import asyncio +from typing import Any, Optional +from multiprocessing import connection +from multiprocessing.reduction import recv_handle + +from .fd import ThreadlessFdExecutor + + +class RemoteFdExecutor(ThreadlessFdExecutor[connection.Connection]): + """A threadless executor implementation which receives work over a connection. + + NOTE: RemoteExecutor uses ``recv_handle`` to accept file descriptors. + + TODO: Refactor and abstract ``recv_handle`` part so that a threaded + remote executor can also accept work over a connection. Currently, + remote executors must be running in a process. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._loop: Optional[asyncio.AbstractEventLoop] = None + + @property + def loop(self) -> Optional[asyncio.AbstractEventLoop]: + if self._loop is None: + self._loop = asyncio.get_event_loop_policy().get_event_loop() + return self._loop + + def receive_from_work_queue(self) -> bool: + # Acceptor will not send address for + # unix socket domain environments. + addr = None + if not self.flags.unix_socket_path: + addr = self.work_queue.recv() + fileno = recv_handle(self.work_queue) + self.work(fileno, addr, None) + return False + + def work_queue_fileno(self) -> Optional[int]: + return self.work_queue.fileno() + + def close_work_queue(self) -> None: + self.work_queue.close() diff --git a/proxy/core/work/local.py b/proxy/core/work/local.py new file mode 100644 index 000000000..0745e817a --- /dev/null +++ b/proxy/core/work/local.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import queue +import asyncio +import contextlib +from typing import Any, Optional + +from .threadless import Threadless +from ...common.backports import NonBlockingQueue + + +class BaseLocalExecutor(Threadless[NonBlockingQueue]): + """A threadless executor implementation which uses a queue to receive new work.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._loop: Optional[asyncio.AbstractEventLoop] = None + + @property + def loop(self) -> Optional[asyncio.AbstractEventLoop]: + if self._loop is None: + self._loop = asyncio.get_event_loop_policy().new_event_loop() + return self._loop + + def work_queue_fileno(self) -> Optional[int]: + return None + + def receive_from_work_queue(self) -> bool: + with contextlib.suppress(queue.Empty): + work = self.work_queue.get() + if isinstance(work, bool) and work is False: + return True + self.work(work) + return False diff --git a/proxy/core/work/pool.py b/proxy/core/work/pool.py new file mode 100644 index 000000000..5458f0a89 --- /dev/null +++ b/proxy/core/work/pool.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. 
+""" +import logging +import argparse +import multiprocessing +from typing import TYPE_CHECKING, Any, List, Type, TypeVar, Optional +from multiprocessing import connection + +from ...common.flag import flags +from ...common.constants import DEFAULT_THREADLESS, DEFAULT_NUM_WORKERS + + +if TYPE_CHECKING: # pragma: no cover + from ..event import EventQueue + from .threadless import Threadless + +T = TypeVar('T', bound='Threadless[Any]') + +logger = logging.getLogger(__name__) + + +flags.add_argument( + '--threadless', + action='store_true', + default=DEFAULT_THREADLESS, + help='Default: ' + ('True' if DEFAULT_THREADLESS else 'False') + '. ' + + 'Enabled by default on Python 3.8+ (mac, linux). ' + + 'When disabled a new thread is spawned ' + 'to handle each client connection.', +) + +flags.add_argument( + '--threaded', + action='store_true', + default=not DEFAULT_THREADLESS, + help='Default: ' + ('True' if not DEFAULT_THREADLESS else 'False') + '. ' + + 'Disabled by default on Python < 3.8 and windows. ' + + 'When enabled a new thread is spawned ' + 'to handle each client connection.', +) + +flags.add_argument( + '--num-workers', + type=int, + default=DEFAULT_NUM_WORKERS, + help='Defaults to number of CPU cores.', +) + + +class ThreadlessPool: + """Manages lifecycle of threadless pool and delegates work to them + using a round-robin strategy. + + Example usage:: + + with ThreadlessPool(flags=...) as pool: + while True: + time.sleep(1) + + If necessary, start multiple threadless pool with different + work classes. + """ + + def __init__( + self, + flags: argparse.Namespace, + executor_klass: Type['T'], + event_queue: Optional['EventQueue'] = None, + ) -> None: + self.flags = flags + self.event_queue = event_queue + # Threadless worker communication states + self.work_queues: List[connection.Connection] = [] + self.work_pids: List[int] = [] + self.work_locks: List['multiprocessing.synchronize.Lock'] = [] + # List of threadless workers + self._executor_klass = executor_klass + # FIXME: Instead of Any type must be the executor klass + self._workers: List[Any] = [] + self._processes: List[multiprocessing.Process] = [] + + def __enter__(self) -> 'ThreadlessPool': + self.setup() + return self + + def __exit__(self, *args: Any) -> None: + self.shutdown() + + def setup(self) -> None: + """Setup threadless processes.""" + if self.flags.threadless: + for index in range(self.flags.num_workers): + self._start_worker(index) + logger.info( + 'Started {0} threadless workers'.format( + self.flags.num_workers, + ), + ) + + def shutdown(self) -> None: + """Shutdown threadless processes.""" + if self.flags.threadless: + self._shutdown_workers() + logger.info( + 'Stopped {0} threadless workers'.format( + self.flags.num_workers, + ), + ) + + def _start_worker(self, index: int) -> None: + """Starts a threadless worker.""" + self.work_locks.append(multiprocessing.Lock()) + pipe = multiprocessing.Pipe() + self.work_queues.append(pipe[0]) + w = self._executor_klass( + iid=str(index), + work_queue=pipe[1], + flags=self.flags, + event_queue=self.event_queue, + ) + self._workers.append(w) + p = multiprocessing.Process(target=w.run) + # p.daemon = True + self._processes.append(p) + p.start() + assert p.pid + self.work_pids.append(p.pid) + logger.debug('Started threadless#%d process#%d', index, p.pid) + + def _shutdown_workers(self) -> None: + """Pop a running threadless worker and clean it up.""" + for index in range(self.flags.num_workers): + self._workers[index].running.set() + for _ in range(self.flags.num_workers): + 
pid = self.work_pids[-1] + self._processes.pop().join() + self._workers.pop() + self.work_pids.pop() + self.work_queues.pop().close() + logger.debug('Stopped threadless process#%d', pid) + self.work_locks = [] diff --git a/proxy/core/work/remote.py b/proxy/core/work/remote.py new file mode 100644 index 000000000..afac2ebef --- /dev/null +++ b/proxy/core/work/remote.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import asyncio +from typing import Any, Optional +from multiprocessing import connection + +from .threadless import Threadless + + +class BaseRemoteExecutor(Threadless[connection.Connection]): + """A threadless executor implementation which receives work over a connection.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._loop: Optional[asyncio.AbstractEventLoop] = None + + @property + def loop(self) -> Optional[asyncio.AbstractEventLoop]: + if self._loop is None: + self._loop = asyncio.get_event_loop_policy().get_event_loop() + return self._loop + + def work_queue_fileno(self) -> Optional[int]: + return self.work_queue.fileno() + + def close_work_queue(self) -> None: + self.work_queue.close() + + def receive_from_work_queue(self) -> bool: + self.work(self.work_queue.recv()) + return False diff --git a/proxy/core/work/task/__init__.py b/proxy/core/work/task/__init__.py new file mode 100644 index 000000000..157ae566d --- /dev/null +++ b/proxy/core/work/task/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +from .task import Task +from .local import LocalTaskExecutor, ThreadedTaskExecutor +from .remote import RemoteTaskExecutor, SingleProcessTaskExecutor +from .handler import TaskHandler + + +__all__ = [ + 'Task', + 'TaskHandler', + 'LocalTaskExecutor', + 'ThreadedTaskExecutor', + 'RemoteTaskExecutor', + 'SingleProcessTaskExecutor', +] diff --git a/proxy/core/work/task/handler.py b/proxy/core/work/task/handler.py new file mode 100644 index 000000000..5fd78e383 --- /dev/null +++ b/proxy/core/work/task/handler.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +from typing import Any + +from .task import Task +from ..work import Work + + +class TaskHandler(Work[Task]): + """Task handler.""" + + @staticmethod + def create(*args: Any) -> Task: + """Work core doesn't know how to create work objects for us. 
+        For example, in the task module scenario, it doesn't know how
+        to create Task objects for us."""
+        return Task(*args)
diff --git a/proxy/core/work/task/local.py b/proxy/core/work/task/local.py
new file mode 100644
index 000000000..a2642b23f
--- /dev/null
+++ b/proxy/core/work/task/local.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+"""
+    proxy.py
+    ~~~~~~~~
+    ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
+    Network monitoring, controls & Application development, testing, debugging.
+
+    :copyright: (c) 2013-present by Abhinav Singh and contributors.
+    :license: BSD, see LICENSE for more details.
+"""
+import time
+import uuid
+import threading
+from typing import Any
+
+from ..local import BaseLocalExecutor
+from ....common.backports import NonBlockingQueue
+
+
+class LocalTaskExecutor(BaseLocalExecutor):
+    """A local executor which is capable of receiving task payloads
+    over a non-blocking queue."""
+
+    def work(self, *args: Any) -> None:
+        task_id = int(time.time())
+        uid = '%s-%s' % (self.iid, task_id)
+        self.works[task_id] = self.create(uid, *args)
+
+
+class ThreadedTaskExecutor(threading.Thread):
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__()
+        self.daemon = True
+        self.executor = LocalTaskExecutor(
+            iid=uuid.uuid4().hex,
+            work_queue=NonBlockingQueue(),
+            **kwargs,
+        )
+
+    def __enter__(self) -> 'ThreadedTaskExecutor':
+        self.start()
+        return self
+
+    def __exit__(self, *args: Any) -> None:
+        self.executor.running.set()
+        self.join()
+
+    def run(self) -> None:
+        self.executor.run()
diff --git a/proxy/core/work/task/remote.py b/proxy/core/work/task/remote.py
new file mode 100644
index 000000000..ce4b0009d
--- /dev/null
+++ b/proxy/core/work/task/remote.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+"""
+    proxy.py
+    ~~~~~~~~
+    ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
+    Network monitoring, controls & Application development, testing, debugging.
+
+    :copyright: (c) 2013-present by Abhinav Singh and contributors.
+    :license: BSD, see LICENSE for more details.
+"""
+import time
+import uuid
+import multiprocessing
+from typing import Any
+
+from ..remote import BaseRemoteExecutor
+
+
+class RemoteTaskExecutor(BaseRemoteExecutor):
+
+    def work(self, *args: Any) -> None:
+        task_id = int(time.time())
+        uid = '%s-%s' % (self.iid, task_id)
+        self.works[task_id] = self.create(uid, *args)
+
+
+class SingleProcessTaskExecutor(multiprocessing.Process):
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__()
+        self.daemon = True
+        self.work_queue, remote = multiprocessing.Pipe()
+        self.executor = RemoteTaskExecutor(
+            iid=uuid.uuid4().hex,
+            work_queue=remote,
+            **kwargs,
+        )
+
+    def __enter__(self) -> 'SingleProcessTaskExecutor':
+        self.start()
+        return self
+
+    def __exit__(self, *args: Any) -> None:
+        self.executor.running.set()
+        self.join()
+
+    def run(self) -> None:
+        self.executor.run()
diff --git a/proxy/core/work/task/task.py b/proxy/core/work/task/task.py
new file mode 100644
index 000000000..f4467ef2c
--- /dev/null
+++ b/proxy/core/work/task/task.py
@@ -0,0 +1,18 @@
+# -*- coding: utf-8 -*-
+"""
+    proxy.py
+    ~~~~~~~~
+    ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
+    Network monitoring, controls & Application development, testing, debugging.
+
+    :copyright: (c) 2013-present by Abhinav Singh and contributors.
+    :license: BSD, see LICENSE for more details.
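+
+    Example usage (illustrative; assumes ``flags`` carries
+    ``work_klass=TaskHandler`` plus whatever attributes the executors
+    expect)::
+
+        with ThreadedTaskExecutor(flags=flags) as executor:
+            executor.executor.work_queue.put(b'payload')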
+""" + + +class Task: + """Task object which known how to process the payload.""" + + def __init__(self, payload: bytes) -> None: + self.payload = payload + print(payload) diff --git a/proxy/core/work/threaded.py b/proxy/core/work/threaded.py new file mode 100644 index 000000000..74899e8cb --- /dev/null +++ b/proxy/core/work/threaded.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import socket +import argparse +import threading +from typing import TYPE_CHECKING, Tuple, TypeVar, Optional + +from ..event import EventQueue, eventNames + + +if TYPE_CHECKING: # pragma: no cover + from .work import Work + from ...common.types import HostPort + +T = TypeVar('T') + + +# TODO: Add generic T +def start_threaded_work( + flags: argparse.Namespace, + conn: socket.socket, + addr: Optional['HostPort'], + event_queue: Optional[EventQueue] = None, + publisher_id: Optional[str] = None, +) -> Tuple['Work[T]', threading.Thread]: + """Utility method to start a work in a new thread.""" + work = flags.work_klass( + flags.work_klass.create(conn, addr), + flags=flags, + event_queue=event_queue, + upstream_conn_pool=None, + ) + # TODO: Keep reference to threads and join during shutdown. + # This will ensure connections are not abruptly closed on shutdown + # for threaded execution mode. + thread = threading.Thread(target=work.run) + thread.daemon = True + thread.start() + work.publish_event( + event_name=eventNames.WORK_STARTED, + event_payload={'fileno': conn.fileno(), 'addr': addr}, + publisher_id=publisher_id or 'thread#{0}'.format( + thread.ident, + ), + ) + return (work, thread) diff --git a/proxy/core/work/threadless.py b/proxy/core/work/threadless.py new file mode 100644 index 000000000..f43c0a473 --- /dev/null +++ b/proxy/core/work/threadless.py @@ -0,0 +1,425 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import os +import asyncio +import logging +import argparse +import selectors +import multiprocessing +from abc import ABC, abstractmethod +from typing import ( + TYPE_CHECKING, Any, Set, Dict, List, Tuple, Generic, TypeVar, Optional, + cast, +) + +from ...common.types import Readables, Writables, SelectableEvents +from ...common.logger import Logger +from ...common.constants import ( + DEFAULT_WAIT_FOR_TASKS_TIMEOUT, DEFAULT_SELECTOR_SELECT_TIMEOUT, + DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT, +) + + +if TYPE_CHECKING: # pragma: no cover + from .work import Work + from ..event import EventQueue + +T = TypeVar('T') + +logger = logging.getLogger(__name__) + + +class Threadless(ABC, Generic[T]): + """Work executor base class. + + Threadless provides an event loop, which is shared across + multiple :class:`~proxy.core.acceptor.work.Work` instances to handle + work. + + Threadless takes input a `work_klass` and an `event_queue`. `work_klass` + must conform to the :class:`~proxy.core.acceptor.work.Work` + protocol. Work is received over the `event_queue`. 
+
+    When a work is accepted, threadless creates a new instance of `work_klass`.
+    Threadless will then invoke the necessary lifecycle methods of the
+    :class:`~proxy.core.acceptor.work.Work` protocol,
+    allowing the `work_klass` implementation to handle the assigned work.
+
+    For example, :class:`~proxy.core.base.tcp_server.BaseTcpServerHandler`
+    implements the :class:`~proxy.core.acceptor.work.Work` protocol. It
+    expects a client connection as work payload and hooks into the
+    threadless event loop to handle the client connection.
+    """
+
+    def __init__(
+        self,
+        iid: str,
+        work_queue: T,
+        flags: argparse.Namespace,
+        event_queue: Optional['EventQueue'] = None,
+    ) -> None:
+        super().__init__()
+        self.iid = iid
+        self.work_queue = work_queue
+        self.flags = flags
+        self.event_queue = event_queue
+
+        self.running = multiprocessing.Event()
+        self.works: Dict[int, 'Work[Any]'] = {}
+        self.selector: Optional[selectors.DefaultSelector] = None
+        # If we remove single quotes for typing hint below,
+        # runtime exceptions will occur for < Python 3.9.
+        #
+        # Ref https://github.com/abhinavsingh/proxy.py/runs/4279055360?check_suite_focus=true
+        self.unfinished: Set['asyncio.Task[bool]'] = set()
+        self.registered_events_by_work_ids: Dict[
+            # work_id
+            int,
+            # fileno, mask
+            SelectableEvents,
+        ] = {}
+        self.wait_timeout: float = DEFAULT_WAIT_FOR_TASKS_TIMEOUT
+        self.cleanup_inactive_timeout: float = DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT
+        self._total: int = 0
+        # When put at the top, causes circular import error
+        # since integrated ssh tunnel was introduced.
+        from ..connection import (  # pylint: disable=C0415
+            UpstreamConnectionPool,
+        )
+        self._upstream_conn_pool: Optional['UpstreamConnectionPool'] = None
+        self._upstream_conn_filenos: Set[int] = set()
+        if self.flags.enable_conn_pool:
+            self._upstream_conn_pool = UpstreamConnectionPool()
+
+    @property
+    @abstractmethod
+    def loop(self) -> Optional[asyncio.AbstractEventLoop]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def receive_from_work_queue(self) -> bool:
+        """Work queue is ready to receive new work.
+
+        Receive it and call ``work``.
+
+        Return True to tear down the loop."""
+        raise NotImplementedError()
+
+    @abstractmethod
+    def work_queue_fileno(self) -> Optional[int]:
+        """If work queue must be selected before calling
+        ``receive_from_work_queue`` then implementation must
+        return work queue fd."""
+        raise NotImplementedError()
+
+    @abstractmethod
+    def work(self, *args: Any) -> None:
+        raise NotImplementedError()
+
+    def create(self, uid: str, *args: Any) -> 'Work[T]':
+        return cast(
+            'Work[T]', self.flags.work_klass(
+                self.flags.work_klass.create(*args),
+                flags=self.flags,
+                event_queue=self.event_queue,
+                uid=uid,
+                upstream_conn_pool=self._upstream_conn_pool,
+            ),
+        )
+
+    def close_work_queue(self) -> None:
+        """Only called if ``work_queue_fileno`` returns an integer.
+        If an fd is select-able for work queue, make sure
+        to close the work queue fd now."""
+        pass  # pragma: no cover
+
+    async def _update_work_events(self, work_id: int) -> None:
+        assert self.selector is not None
+        worker_events = await self.works[work_id].get_events()
+        # NOTE: Current assumption is that multiple works will not
+        # be interested in the same fd. Descriptors of interest
+        # returned by work must be unique.
+        #
+        # TODO: Ideally we must diff and unregister socks that are no
+        # longer of interest within the current _select_events call
+        # but still exist in the registered_socks_by_work_ids registry.
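+        # Three cases are handled in the loop below: an fd this work has
+        # already registered (modify it if the mask changed), an fd that
+        # the upstream connection pool currently owns (borrow it), and a
+        # brand new fd (register it fresh).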
+ for fileno in worker_events: + if work_id not in self.registered_events_by_work_ids: + self.registered_events_by_work_ids[work_id] = {} + mask = worker_events[fileno] + if fileno in self.registered_events_by_work_ids[work_id]: + oldmask = self.registered_events_by_work_ids[work_id][fileno] + if mask != oldmask: + self.selector.modify( + fileno, events=mask, + data=work_id, + ) + self.registered_events_by_work_ids[work_id][fileno] = mask + logger.debug( + 'fd#{0} modified for mask#{1} by work#{2}'.format( + fileno, mask, work_id, + ), + ) + # else: + # logger.info( + # 'fd#{0} by work#{1} not modified'.format(fileno, work_id)) + elif fileno in self._upstream_conn_filenos: + # Descriptor offered by work, but is already registered by connection pool + # Most likely because work has acquired a reusable connection. + self.selector.modify(fileno, events=mask, data=work_id) + self.registered_events_by_work_ids[work_id][fileno] = mask + self._upstream_conn_filenos.remove(fileno) + logger.debug( + 'fd#{0} borrowed with mask#{1} by work#{2}'.format( + fileno, mask, work_id, + ), + ) + # Can throw ValueError: Invalid file descriptor: -1 + # + # A guard within Work classes may not help here due to + # asynchronous nature. Hence, threadless will handle + # ValueError exceptions raised by selector.register + # for invalid fd. + # + # TODO: Also remove offending work from pool to avoid spin loop. + elif fileno != -1: + self.selector.register(fileno, events=mask, data=work_id) + self.registered_events_by_work_ids[work_id][fileno] = mask + logger.debug( + 'fd#{0} registered for mask#{1} by work#{2}'.format( + fileno, mask, work_id, + ), + ) + + async def _update_conn_pool_events(self) -> None: + if not self._upstream_conn_pool: + return + assert self.selector is not None + new_conn_pool_events = await self._upstream_conn_pool.get_events() + old_conn_pool_filenos = self._upstream_conn_filenos.copy() + self._upstream_conn_filenos.clear() + new_conn_pool_filenos = set(new_conn_pool_events.keys()) + new_conn_pool_filenos.difference_update(old_conn_pool_filenos) + for fileno in new_conn_pool_filenos: + self.selector.register( + fileno, + events=new_conn_pool_events[fileno], + data=0, + ) + self._upstream_conn_filenos.add(fileno) + old_conn_pool_filenos.difference_update(self._upstream_conn_filenos) + for fileno in old_conn_pool_filenos: + self.selector.unregister(fileno) + + async def _update_selector(self) -> None: + assert self.selector is not None + unfinished_work_ids = set() + for task in self.unfinished: + unfinished_work_ids.add(task._work_id) # type: ignore + for work_id in self.works: + # We don't want to invoke work objects which haven't + # yet finished their previous task + if work_id in unfinished_work_ids: + continue + await self._update_work_events(work_id) + await self._update_conn_pool_events() + + async def _selected_events(self) -> Tuple[ + Dict[int, Tuple[Readables, Writables]], + bool, + ]: + """For each work, collects events that they are interested in. + Calls select for events of interest. + + Returns a 2-tuple containing a dictionary and boolean. + Dictionary keys are work IDs and values are 2-tuple + containing ready readables & writables. + + Returned boolean value indicates whether there is + a newly accepted work waiting to be received and + queued for processing. This is only applicable when + :class:`~proxy.core.work.threadless.Threadless.work_queue_fileno` + returns a valid fd. 
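+
+        For example (illustrative), a work with id ``3`` whose fd ``7``
+        is ready for reading would show up as ``{3: ([7], [])}``.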
+ """ + assert self.selector is not None + await self._update_selector() + # Keys are work_id and values are 2-tuple indicating + # readables & writables that work_id is interested in + # and are ready for IO. + work_by_ids: Dict[int, Tuple[Readables, Writables]] = {} + new_work_available = False + wqfileno = self.work_queue_fileno() + if wqfileno is None: + # When ``work_queue_fileno`` returns None, + # always return True for the boolean value. + new_work_available = True + + events = self.selector.select( + timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT, + ) + + for key, mask in events: + if not new_work_available and wqfileno is not None and key.fileobj == wqfileno: + assert mask & selectors.EVENT_READ + new_work_available = True + continue + if key.data not in work_by_ids: + work_by_ids[key.data] = ([], []) + if mask & selectors.EVENT_READ: + work_by_ids[key.data][0].append(key.fd) + if mask & selectors.EVENT_WRITE: + work_by_ids[key.data][1].append(key.fd) + return (work_by_ids, new_work_available) + + async def _wait_for_tasks(self) -> Set['asyncio.Task[bool]']: + finished, self.unfinished = await asyncio.wait( + self.unfinished, + timeout=self.wait_timeout, + return_when=asyncio.FIRST_COMPLETED, + ) + return finished # noqa: WPS331 + + def _cleanup_inactive(self) -> None: + inactive_works: List[int] = [] + for work_id in self.works: + if self.works[work_id].is_inactive(): + inactive_works.append(work_id) + for work_id in inactive_works: + self._cleanup(work_id) + + # TODO: HttpProtocolHandler.shutdown can call flush which may block + def _cleanup(self, work_id: int) -> None: + if work_id in self.registered_events_by_work_ids: + assert self.selector + for fileno in self.registered_events_by_work_ids[work_id]: + logger.debug( + 'fd#{0} unregistered by work#{1}'.format( + fileno, work_id, + ), + ) + self.selector.unregister(fileno) + self.registered_events_by_work_ids[work_id].clear() + del self.registered_events_by_work_ids[work_id] + self.works[work_id].shutdown() + del self.works[work_id] + if self.work_queue_fileno() is not None: + os.close(work_id) + + def _create_tasks( + self, + work_by_ids: Dict[int, Tuple[Readables, Writables]], + ) -> Set['asyncio.Task[bool]']: + assert self.loop + tasks: Set['asyncio.Task[bool]'] = set() + for work_id in work_by_ids: + if work_id == 0: + assert self._upstream_conn_pool + task = self.loop.create_task( + self._upstream_conn_pool.handle_events( + *work_by_ids[work_id], + ), + ) + else: + task = self.loop.create_task( + self.works[work_id].handle_events(*work_by_ids[work_id]), + ) + task._work_id = work_id # type: ignore[attr-defined] + # task.set_name(work_id) + tasks.add(task) + return tasks + + async def _run_once(self) -> bool: + assert self.loop is not None + work_by_ids, new_work_available = await self._selected_events() + # Accept new work if available + # + # TODO: We must use a work klass to handle + # client_queue fd itself a.k.a. accept_client + # will become handle_readables. + if new_work_available: + teardown = self.receive_from_work_queue() + if teardown: + return teardown + if len(work_by_ids) == 0: + return False + # Invoke Threadless.handle_events + self.unfinished.update(self._create_tasks(work_by_ids)) + # logger.debug('Executing {0} works'.format(len(self.unfinished))) + # Cleanup finished tasks + for task in await self._wait_for_tasks(): + # Checking for result can raise exception e.g. + # CancelledError, InvalidStateError or an exception + # from underlying task e.g. TimeoutError. 
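+            # The ``finally`` below guarantees cleanup still runs when
+            # ``result()`` raises, so a failed task cannot leak its
+            # registered descriptors.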
+ teardown = False + work_id = task._work_id # type: ignore + try: + teardown = task.result() + finally: + if teardown: + self._cleanup(work_id) + # self.cleanup(int(task.get_name())) + # logger.debug( + # 'Done executing works, {0} pending, {1} registered'.format( + # len(self.unfinished), len(self.registered_events_by_work_ids), + # ), + # ) + return False + + async def _run_forever(self) -> None: + tick = 0 + try: + while True: + if await self._run_once(): + break + # Check for inactive and shutdown signal + elapsed = tick * \ + (DEFAULT_SELECTOR_SELECT_TIMEOUT + self.wait_timeout) + if elapsed >= self.cleanup_inactive_timeout: + self._cleanup_inactive() + if self.running.is_set(): + break + tick = 0 + tick += 1 + except KeyboardInterrupt: + pass + finally: + if self.loop: + self.loop.stop() + + def run(self) -> None: + Logger.setup( + self.flags.log_file, self.flags.log_level, + self.flags.log_format, + ) + wqfileno = self.work_queue_fileno() + try: + self.selector = selectors.DefaultSelector() + if wqfileno is not None: + self.selector.register( + wqfileno, + selectors.EVENT_READ, + data=wqfileno, + ) + assert self.loop + logger.debug('Working on {0} works'.format(len(self.works))) + self.loop.create_task(self._run_forever()) + self.loop.run_forever() + except KeyboardInterrupt: + pass + finally: + assert self.selector is not None + if wqfileno is not None: + self.selector.unregister(wqfileno) + self.close_work_queue() + self.selector.close() + assert self.loop is not None + self.loop.run_until_complete(self.loop.shutdown_asyncgens()) + self.loop.close() diff --git a/proxy/core/work/work.py b/proxy/core/work/work.py new file mode 100644 index 000000000..d68969a72 --- /dev/null +++ b/proxy/core/work/work.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + .. spelling:: + + acceptor +""" +import argparse +from abc import ABC, abstractmethod +from uuid import uuid4 +from typing import TYPE_CHECKING, Any, Dict, Generic, TypeVar, Optional + +from ..event import EventQueue, eventNames +from ...common.types import Readables, Writables, SelectableEvents + + +if TYPE_CHECKING: # pragma: no cover + from ..connection import UpstreamConnectionPool + +T = TypeVar('T') + + +class Work(ABC, Generic[T]): + """Implement Work to hook into the event loop provided by Threadless process.""" + + def __init__( + self, + work: T, + flags: argparse.Namespace, + event_queue: Optional[EventQueue] = None, + uid: Optional[str] = None, + upstream_conn_pool: Optional['UpstreamConnectionPool'] = None, + ) -> None: + # Work uuid + self.uid: str = uid if uid is not None else uuid4().hex + self.flags = flags + # Eventing core queue + self.event_queue = event_queue + # Accept work + self.work = work + self.upstream_conn_pool = upstream_conn_pool + + @staticmethod + @abstractmethod + def create(*args: Any) -> T: + """Implementations are responsible for creation of work objects + from incoming args. 
This helps keep work core agnostic to + creation of externally defined work class objects.""" + raise NotImplementedError() + + async def get_events(self) -> SelectableEvents: + """Return sockets and events (read or write) that we are interested in.""" + return {} # pragma: no cover + + async def handle_events( + self, + _readables: Readables, + _writables: Writables, + ) -> bool: + """Handle readable and writable sockets. + + Return True to shutdown work.""" + return False # pragma: no cover + + def initialize(self) -> None: + """Perform any resource initialization.""" + pass # pragma: no cover + + def is_inactive(self) -> bool: + """Return True if connection should be considered inactive.""" + return False # pragma: no cover + + def shutdown(self) -> None: + """Implementation must close any opened resources here + and call super().shutdown().""" + self.publish_event( + event_name=eventNames.WORK_FINISHED, + event_payload={}, + publisher_id=self.__class__.__name__, + ) + + def run(self) -> None: + """run() method is not used by Threadless. It's here for backward + compatibility with threaded mode where work class is started as + a separate thread. + """ + pass # pragma: no cover + + def publish_event( + self, + event_name: int, + event_payload: Dict[str, Any], + publisher_id: Optional[str] = None, + ) -> None: + """Convenience method provided to publish events into the global event queue.""" + if not self.flags.enable_events: + return + assert self.event_queue + self.event_queue.publish( + self.uid, + event_name, + event_payload, + publisher_id, + ) diff --git a/proxy/http/__init__.py b/proxy/http/__init__.py index 232621f0b..37826426a 100644 --- a/proxy/http/__init__.py +++ b/proxy/http/__init__.py @@ -8,3 +8,21 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ +from .url import Url +from .codes import httpStatusCodes +from .plugin import HttpProtocolHandlerPlugin +from .handler import HttpProtocolHandler +from .headers import httpHeaders +from .methods import httpMethods +from .connection import HttpClientConnection + + +__all__ = [ + 'HttpProtocolHandler', + 'HttpClientConnection', + 'HttpProtocolHandlerPlugin', + 'httpStatusCodes', + 'httpMethods', + 'httpHeaders', + 'Url', +] diff --git a/proxy/http/codes.py b/proxy/http/codes.py index 042d27e4f..ad6716090 100644 --- a/proxy/http/codes.py +++ b/proxy/http/codes.py @@ -7,41 +7,49 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. 
spelling::
+
+       http
+       iterable
 """
 from typing import NamedTuple
 
 
-HttpStatusCodes = NamedTuple('HttpStatusCodes', [
-    # 1xx
-    ('CONTINUE', int),
-    ('SWITCHING_PROTOCOLS', int),
-    # 2xx
-    ('OK', int),
-    # 3xx
-    ('MOVED_PERMANENTLY', int),
-    ('SEE_OTHER', int),
-    ('TEMPORARY_REDIRECT', int),
-    ('PERMANENT_REDIRECT', int),
-    # 4xx
-    ('BAD_REQUEST', int),
-    ('UNAUTHORIZED', int),
-    ('FORBIDDEN', int),
-    ('NOT_FOUND', int),
-    ('PROXY_AUTH_REQUIRED', int),
-    ('REQUEST_TIMEOUT', int),
-    ('I_AM_A_TEAPOT', int),
-    # 5xx
-    ('INTERNAL_SERVER_ERROR', int),
-    ('NOT_IMPLEMENTED', int),
-    ('BAD_GATEWAY', int),
-    ('GATEWAY_TIMEOUT', int),
-    ('NETWORK_READ_TIMEOUT_ERROR', int),
-    ('NETWORK_CONNECT_TIMEOUT_ERROR', int),
-])
+HttpStatusCodes = NamedTuple(
+    'HttpStatusCodes', [
+        # 1xx
+        ('CONTINUE', int),
+        ('SWITCHING_PROTOCOLS', int),
+        # 2xx
+        ('OK', int),
+        # 3xx
+        ('MOVED_PERMANENTLY', int),
+        ('SEE_OTHER', int),
+        ('TEMPORARY_REDIRECT', int),
+        ('PERMANENT_REDIRECT', int),
+        # 4xx
+        ('BAD_REQUEST', int),
+        ('UNAUTHORIZED', int),
+        ('FORBIDDEN', int),
+        ('NOT_FOUND', int),
+        ('PROXY_AUTH_REQUIRED', int),
+        ('REQUEST_TIMEOUT', int),
+        ('I_AM_A_TEAPOT', int),
+        # 5xx
+        ('INTERNAL_SERVER_ERROR', int),
+        ('NOT_IMPLEMENTED', int),
+        ('BAD_GATEWAY', int),
+        ('GATEWAY_TIMEOUT', int),
+        ('NETWORK_READ_TIMEOUT_ERROR', int),
+        ('NETWORK_CONNECT_TIMEOUT_ERROR', int),
+    ],
+)
+
 httpStatusCodes = HttpStatusCodes(
     100, 101,
     200,
     301, 303, 307, 308,
     400, 401, 403, 404, 407, 408, 418,
-    500, 501, 502, 504, 598, 599
+    500, 501, 502, 504, 598, 599,
 )
diff --git a/proxy/http/connection.py b/proxy/http/connection.py
new file mode 100644
index 000000000..d31d34036
--- /dev/null
+++ b/proxy/http/connection.py
@@ -0,0 +1,20 @@
+# -*- coding: utf-8 -*-
+"""
+    proxy.py
+    ~~~~~~~~
+    ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
+    Network monitoring, controls & Application development, testing, debugging.
+
+    :copyright: (c) 2013-present by Abhinav Singh and contributors.
+    :license: BSD, see LICENSE for more details.
+
+    .. spelling::
+
+       http
+       iterable
+"""
+from ..core.connection import TcpClientConnection
+
+
+class HttpClientConnection(TcpClientConnection):
+    pass
diff --git a/proxy/http/descriptors.py b/proxy/http/descriptors.py
new file mode 100644
index 000000000..ef73496a5
--- /dev/null
+++ b/proxy/http/descriptors.py
@@ -0,0 +1,40 @@
+# -*- coding: utf-8 -*-
+"""
+    proxy.py
+    ~~~~~~~~
+    ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
+    Network monitoring, controls & Application development, testing, debugging.
+
+    :copyright: (c) 2013-present by Abhinav Singh and contributors.
+    :license: BSD, see LICENSE for more details.
+"""
+from ..common.types import Readables, Writables, Descriptors
+
+
+# Since 3.4.0
+class DescriptorsHandlerMixin:
+    """DescriptorsHandlerMixin provides an abstraction used by several core
+    HTTP modules, including web and proxy plugins. By using
+    DescriptorsHandlerMixin, a class becomes compliant with the core
+    event loop."""
+
+    # @abstractmethod
+    async def get_descriptors(self) -> Descriptors:
+        """Implementations must return a list of descriptors that they wish to
+        read from and write into."""
+        return [], []  # pragma: no cover
+
+    # @abstractmethod
+    async def write_to_descriptors(self, w: Writables) -> bool:
+        """Implementations must now write/flush data over the socket.
+
+        Note that buffer management is built into the connection classes.
+ Hence implementations MUST call + :meth:`~proxy.core.connection.connection.TcpConnection.flush` + here, to send any buffered data over the socket. + """ + return False # pragma: no cover + + # @abstractmethod + async def read_from_descriptors(self, r: Readables) -> bool: + """Implementations must now read data over the socket.""" + return False # pragma: no cover diff --git a/proxy/http/exception/__init__.py b/proxy/http/exception/__init__.py index 513d2bd51..68776e923 100644 --- a/proxy/http/exception/__init__.py +++ b/proxy/http/exception/__init__.py @@ -9,9 +9,10 @@ :license: BSD, see LICENSE for more details. """ from .base import HttpProtocolException -from .http_request_rejected import HttpRequestRejected from .proxy_auth_failed import ProxyAuthenticationFailed from .proxy_conn_failed import ProxyConnectionFailed +from .http_request_rejected import HttpRequestRejected + __all__ = [ 'HttpProtocolException', diff --git a/proxy/http/exception/base.py b/proxy/http/exception/base.py index 65138e87b..bd1233bae 100644 --- a/proxy/http/exception/base.py +++ b/proxy/http/exception/base.py @@ -7,18 +7,28 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. spelling:: + + http """ -from typing import Optional +from typing import TYPE_CHECKING, Any, Optional -from ..parser import HttpParser + +if TYPE_CHECKING: # pragma: no cover + from ..parser import HttpParser class HttpProtocolException(Exception): - """Top level HttpProtocolException exception class. + """Top level :exc:`HttpProtocolException` exception class. + + All exceptions raised during execution of HTTP request lifecycle MUST + inherit :exc:`HttpProtocolException` base class. Implement + ``response()`` method to optionally return custom response to client. + """ - All exceptions raised during execution of Http request lifecycle MUST - inherit HttpProtocolException base class. Implement response() method - to optionally return custom response to client.""" + def __init__(self, message: Optional[str] = None, **kwargs: Any) -> None: + super().__init__(message or 'Reason unknown') - def response(self, request: HttpParser) -> Optional[memoryview]: + def response(self, request: 'HttpParser') -> Optional[memoryview]: return None # pragma: no cover diff --git a/proxy/http/exception/http_request_rejected.py b/proxy/http/exception/http_request_rejected.py index 46fd9b04a..2b2e7a13b 100644 --- a/proxy/http/exception/http_request_rejected.py +++ b/proxy/http/exception/http_request_rejected.py @@ -8,35 +8,51 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ -from typing import Optional, Dict +from typing import TYPE_CHECKING, Any, Dict, Optional from .base import HttpProtocolException -from ..parser import HttpParser from ...common.utils import build_http_response +if TYPE_CHECKING: # pragma: no cover + from ..parser import HttpParser + + class HttpRequestRejected(HttpProtocolException): """Generic exception that can be used to reject the client requests. 
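For reference, a minimal sketch of how this exception surface is meant to be consumed from plugin code, using only names introduced by this patch (the surrounding plugin callback is assumed):

    from proxy.http import httpStatusCodes
    from proxy.http.exception import HttpRequestRejected

    # Inside a plugin callback, reject the client with a custom status:
    raise HttpRequestRejected(
        status_code=httpStatusCodes.I_AM_A_TEAPOT,
        reason=b"I'm a teapot",
        body=b'request rejected',
    )

Because the new constructor forwards a formatted message to HttpProtocolException, the raised exception now stringifies with its class name and reason, which makes the handler's `logger.info('HttpProtocolException: %s' % e)` line meaningful.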
Connections can either be dropped/closed or optionally an HTTP status code can be returned.""" - def __init__(self, - status_code: Optional[int] = None, - reason: Optional[bytes] = None, - headers: Optional[Dict[bytes, bytes]] = None, - body: Optional[bytes] = None): + def __init__( + self, + status_code: Optional[int] = None, + reason: Optional[bytes] = None, + headers: Optional[Dict[bytes, bytes]] = None, + body: Optional[bytes] = None, + **kwargs: Any, + ): self.status_code: Optional[int] = status_code self.reason: Optional[bytes] = reason self.headers: Optional[Dict[bytes, bytes]] = headers self.body: Optional[bytes] = body + klass_name = self.__class__.__name__ + super().__init__( + message='%s %r' % (klass_name, reason) + if reason + else klass_name, + **kwargs, + ) - def response(self, _request: HttpParser) -> Optional[memoryview]: + def response(self, _request: 'HttpParser') -> Optional[memoryview]: if self.status_code: - return memoryview(build_http_response( - status_code=self.status_code, - reason=self.reason, - headers=self.headers, - body=self.body - )) + return memoryview( + build_http_response( + status_code=self.status_code, + reason=self.reason, + headers=self.headers, + body=self.body, + conn_close=True, + ), + ) return None diff --git a/proxy/http/exception/proxy_auth_failed.py b/proxy/http/exception/proxy_auth_failed.py index ae1c6a444..afb2e4048 100644 --- a/proxy/http/exception/proxy_auth_failed.py +++ b/proxy/http/exception/proxy_auth_failed.py @@ -7,28 +7,28 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. spelling:: + + auth + http """ +from typing import TYPE_CHECKING, Any + from .base import HttpProtocolException -from ..parser import HttpParser -from ..codes import httpStatusCodes +from ..responses import PROXY_AUTH_FAILED_RESPONSE_PKT -from ...common.constants import PROXY_AGENT_HEADER_VALUE, PROXY_AGENT_HEADER_KEY -from ...common.utils import build_http_response + +if TYPE_CHECKING: # pragma: no cover + from ..parser import HttpParser class ProxyAuthenticationFailed(HttpProtocolException): - """Exception raised when Http Proxy auth is enabled and + """Exception raised when HTTP Proxy auth is enabled and incoming request doesn't present necessary credentials.""" - RESPONSE_PKT = memoryview(build_http_response( - httpStatusCodes.PROXY_AUTH_REQUIRED, - reason=b'Proxy Authentication Required', - headers={ - PROXY_AGENT_HEADER_KEY: PROXY_AGENT_HEADER_VALUE, - b'Proxy-Authenticate': b'Basic', - b'Connection': b'close', - }, - body=b'Proxy Authentication Required')) - - def response(self, _request: HttpParser) -> memoryview: - return self.RESPONSE_PKT + def __init__(self, **kwargs: Any) -> None: + super().__init__(self.__class__.__name__, **kwargs) + + def response(self, _request: 'HttpParser') -> memoryview: + return PROXY_AUTH_FAILED_RESPONSE_PKT diff --git a/proxy/http/exception/proxy_conn_failed.py b/proxy/http/exception/proxy_conn_failed.py index 0cec22427..2001b3360 100644 --- a/proxy/http/exception/proxy_conn_failed.py +++ b/proxy/http/exception/proxy_conn_failed.py @@ -7,32 +7,29 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. 
spelling:: + + conn """ +from typing import TYPE_CHECKING, Any + from .base import HttpProtocolException -from ..parser import HttpParser -from ..codes import httpStatusCodes +from ..responses import BAD_GATEWAY_RESPONSE_PKT -from ...common.constants import PROXY_AGENT_HEADER_VALUE, PROXY_AGENT_HEADER_KEY -from ...common.utils import build_http_response + +if TYPE_CHECKING: # pragma: no cover + from ..parser import HttpParser class ProxyConnectionFailed(HttpProtocolException): - """Exception raised when HttpProxyPlugin is unable to establish connection to upstream server.""" - - RESPONSE_PKT = memoryview(build_http_response( - httpStatusCodes.BAD_GATEWAY, - reason=b'Bad Gateway', - headers={ - PROXY_AGENT_HEADER_KEY: PROXY_AGENT_HEADER_VALUE, - b'Connection': b'close' - }, - body=b'Bad Gateway' - )) - - def __init__(self, host: str, port: int, reason: str): + """Exception raised when ``HttpProxyPlugin`` is unable to establish connection to upstream server.""" + + def __init__(self, host: str, port: int, reason: str, **kwargs: Any): self.host: str = host self.port: int = port self.reason: str = reason + super().__init__('%s %s' % (self.__class__.__name__, reason), **kwargs) - def response(self, _request: HttpParser) -> memoryview: - return self.RESPONSE_PKT + def response(self, _request: 'HttpParser') -> memoryview: + return BAD_GATEWAY_RESPONSE_PKT diff --git a/proxy/http/handler.py b/proxy/http/handler.py index 9c7dd90c6..b8d207d56 100644 --- a/proxy/http/handler.py +++ b/proxy/http/handler.py @@ -8,402 +8,332 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ -import socket -import selectors import ssl import time -import contextlib import errno +import socket +import asyncio import logging -from abc import ABC, abstractmethod -from typing import Tuple, List, Union, Optional, Generator, Dict -from uuid import UUID -from .parser import HttpParser, httpParserStates, httpParserTypes +import selectors +from typing import Any, List, Type, Tuple, Optional + +from .parser import HttpParser, httpParserTypes, httpParserStates +from .plugin import HttpProtocolHandlerPlugin from .exception import HttpProtocolException +from .protocols import httpProtocols +from .responses import BAD_REQUEST_RESPONSE_PKT +from ..core.base import BaseTcpServerHandler +from .connection import HttpClientConnection +from ..common.types import Readables, Writables, SelectableEvents +from ..common.constants import DEFAULT_SELECTOR_SELECT_TIMEOUT -from ..common.flags import Flags -from ..common.types import HasFileno -from ..core.threadless import ThreadlessWork -from ..core.event import EventQueue -from ..core.connection import TcpClientConnection logger = logging.getLogger(__name__) -class HttpProtocolHandlerPlugin(ABC): - """Base HttpProtocolHandler Plugin class. - - NOTE: This is an internal plugin and in most cases only useful for core contributors. - If you are looking for proxy server plugins see ``. - - Implements various lifecycle events for an accepted client connection. - Following events are of interest: - - 1. Client Connection Accepted - A new plugin instance is created per accepted client connection. - Add your logic within __init__ constructor for any per connection setup. - 2. Client Request Chunk Received - on_client_data is called for every chunk of data sent by the client. - 3. Client Request Complete - on_request_complete is called once client request has completed. - 4. 
Server Response Chunk Received - on_response_chunk is called for every chunk received from the server. - 5. Client Connection Closed - Add your logic within `on_client_connection_close` for any per connection teardown. - """ - - def __init__( - self, - uid: UUID, - flags: Flags, - client: TcpClientConnection, - request: HttpParser, - event_queue: EventQueue): - self.uid: UUID = uid - self.flags: Flags = flags - self.client: TcpClientConnection = client - self.request: HttpParser = request - self.event_queue = event_queue - super().__init__() - - def name(self) -> str: - """A unique name for your plugin. - - Defaults to name of the class. This helps plugin developers to directly - access a specific plugin by its name.""" - return self.__class__.__name__ - - @abstractmethod - def get_descriptors( - self) -> Tuple[List[socket.socket], List[socket.socket]]: - return [], [] # pragma: no cover - - @abstractmethod - def write_to_descriptors(self, w: List[Union[int, HasFileno]]) -> bool: - return False # pragma: no cover - - @abstractmethod - def read_from_descriptors(self, r: List[Union[int, HasFileno]]) -> bool: - return False # pragma: no cover - - @abstractmethod - def on_client_data(self, raw: memoryview) -> Optional[memoryview]: - return raw # pragma: no cover - - @abstractmethod - def on_request_complete(self) -> Union[socket.socket, bool]: - """Called right after client request parser has reached COMPLETE state.""" - return False # pragma: no cover - - @abstractmethod - def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]: - """Handle data chunks as received from the server. - - Return optionally modified chunk to return back to client.""" - return chunk # pragma: no cover - - @abstractmethod - def on_client_connection_close(self) -> None: - pass # pragma: no cover - - -class HttpProtocolHandler(ThreadlessWork): +class HttpProtocolHandler(BaseTcpServerHandler[HttpClientConnection]): """HTTP, HTTPS, HTTP2, WebSockets protocol handler. - Accepts `Client` connection object and manages HttpProtocolHandlerPlugin invocations. + Accepts `Client` connection and delegates to HttpProtocolHandlerPlugin. """ - def __init__(self, client: TcpClientConnection, - flags: Optional[Flags] = None, - event_queue: Optional[EventQueue] = None, - uid: Optional[UUID] = None): - super().__init__(client, flags, event_queue, uid) - + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) self.start_time: float = time.time() self.last_activity: float = self.start_time - self.request: HttpParser = HttpParser(httpParserTypes.REQUEST_PARSER) - self.response: HttpParser = HttpParser(httpParserTypes.RESPONSE_PARSER) - self.selector = selectors.DefaultSelector() - self.client: TcpClientConnection = client - self.plugins: Dict[str, HttpProtocolHandlerPlugin] = {} + self.request: HttpParser = HttpParser( + httpParserTypes.REQUEST_PARSER, + enable_proxy_protocol=self.flags.enable_proxy_protocol, + ) + self.selector: Optional[selectors.DefaultSelector] = None + if not self.flags.threadless: + self.selector = selectors.DefaultSelector() + self.plugin: Optional[HttpProtocolHandlerPlugin] = None + + ## + # initialize, is_inactive, shutdown, get_events, handle_events + # overrides Work class definitions. 
+ ## + + @staticmethod + def create(*args: Any) -> HttpClientConnection: # pragma: no cover + return HttpClientConnection(*args) def initialize(self) -> None: - """Optionally upgrades connection to HTTPS, set conn in non-blocking mode and initializes plugins.""" - conn = self.optionally_wrap_socket(self.client.connection) - conn.setblocking(False) - if self.flags.encryption_enabled(): - self.client = TcpClientConnection(conn=conn, addr=self.client.addr) - if b'HttpProtocolHandlerPlugin' in self.flags.plugins: - for klass in self.flags.plugins[b'HttpProtocolHandlerPlugin']: - instance = klass( - self.uid, - self.flags, - self.client, - self.request, - self.event_queue) - self.plugins[instance.name()] = instance - logger.debug('Handling connection %r' % self.client.connection) + super().initialize() + if self._encryption_enabled(): + self.work = HttpClientConnection( + conn=self.work.connection, + addr=self.work.addr, + ) def is_inactive(self) -> bool: - if not self.client.has_buffer() and \ - self.connection_inactive_for() > self.flags.timeout: + if not self.work.has_buffer() and \ + self._connection_inactive_for() > self.flags.timeout: return True return False - def get_events(self) -> Dict[socket.socket, int]: - events: Dict[socket.socket, int] = { - self.client.connection: selectors.EVENT_READ - } - if self.client.has_buffer(): - events[self.client.connection] |= selectors.EVENT_WRITE + def shutdown(self) -> None: + try: + # Flush pending buffer in threaded mode only. + # + # For threadless mode, BaseTcpServerHandler implements + # the must_flush_before_shutdown logic automagically. + if self.selector and self.work.has_buffer(): + self._flush() + # Invoke plugin.on_client_connection_close + if self.plugin: + self.plugin.on_client_connection_close() + logger.debug( + 'Closing client connection %s has buffer %s' % + (self.work.address, self.work.has_buffer()), + ) + conn = self.work.connection + # Unwrap if wrapped before shutdown. + if self._encryption_enabled() and \ + isinstance(self.work.connection, ssl.SSLSocket): + conn = self.work.connection.unwrap() + conn.shutdown(socket.SHUT_WR) + logger.debug('Client connection shutdown successful') + except OSError: + pass + finally: + # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data + # could lead to an immediate reset being sent. + # + # "A host MAY implement a 'half-duplex' TCP close sequence, so that an application + # that has called CLOSE cannot continue to read data from the connection. + # If such a host issues a CLOSE call while received data is still pending in TCP, + # or if new data is received after CLOSE is called, its TCP SHOULD send a RST to + # show that data was lost." 
+ # + self.work.connection.close() + logger.debug('Client connection closed') + super().shutdown() + async def get_events(self) -> SelectableEvents: + # Get default client events + events: SelectableEvents = await super().get_events() # HttpProtocolHandlerPlugin.get_descriptors - for plugin in self.plugins.values(): - plugin_read_desc, plugin_write_desc = plugin.get_descriptors() - for r in plugin_read_desc: - if r not in events: - events[r] = selectors.EVENT_READ + if self.plugin: + plugin_read_desc, plugin_write_desc = await self.plugin.get_descriptors() + for rfileno in plugin_read_desc: + if rfileno not in events: + events[rfileno] = selectors.EVENT_READ else: - events[r] |= selectors.EVENT_READ - for w in plugin_write_desc: - if w not in events: - events[w] = selectors.EVENT_WRITE + events[rfileno] |= selectors.EVENT_READ + for wfileno in plugin_write_desc: + if wfileno not in events: + events[wfileno] = selectors.EVENT_WRITE else: - events[w] |= selectors.EVENT_WRITE - + events[wfileno] |= selectors.EVENT_WRITE return events - def handle_events( + # We override super().handle_events and never call it + async def handle_events( self, - readables: List[Union[int, HasFileno]], - writables: List[Union[int, HasFileno]]) -> bool: - """Returns True if proxy must teardown.""" + readables: Readables, + writables: Writables, + ) -> bool: + """Returns True if proxy must tear down.""" # Flush buffer for ready to write sockets - teardown = self.handle_writables(writables) + teardown = await self.handle_writables(writables) if teardown: return True - # Invoke plugin.write_to_descriptors - for plugin in self.plugins.values(): - teardown = plugin.write_to_descriptors(writables) + if self.plugin: + teardown = await self.plugin.write_to_descriptors(writables) if teardown: return True - # Read from ready to read sockets - teardown = self.handle_readables(readables) + teardown = await self.handle_readables(readables) if teardown: return True - # Invoke plugin.read_from_descriptors - for plugin in self.plugins.values(): - teardown = plugin.read_from_descriptors(readables) + if self.plugin: + teardown = await self.plugin.read_from_descriptors(readables) if teardown: return True - return False - def shutdown(self) -> None: - try: - # Flush pending buffer if any - self.flush() - - # Invoke plugin.on_client_connection_close - for plugin in self.plugins.values(): - plugin.on_client_connection_close() - - logger.debug( - 'Closing client connection %r ' - 'at address %r has buffer %s' % - (self.client.connection, self.client.addr, self.client.has_buffer())) - - conn = self.client.connection - # Unwrap if wrapped before shutdown. - if self.flags.encryption_enabled() and \ - isinstance(self.client.connection, ssl.SSLSocket): - conn = self.client.connection.unwrap() - conn.shutdown(socket.SHUT_WR) - logger.debug('Client connection shutdown successful') - except OSError: - pass - finally: - self.client.connection.close() - logger.debug('Client connection closed') - super().shutdown() - - def optionally_wrap_socket( - self, conn: socket.socket) -> Union[ssl.SSLSocket, socket.socket]: - """Attempts to wrap accepted client connection using provided certificates. - - Shutdown and closes client connection upon error. 
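The RFC 1122 note above is why shutdown() half-closes with SHUT_WR before calling close(). A standalone illustration of the same pattern, standard library only (draining the peer until EOF is the assumed desired behaviour):

    import socket

    def graceful_close(conn: socket.socket) -> None:
        # Signal EOF to the peer first ...
        conn.shutdown(socket.SHUT_WR)
        # ... then drain anything still in flight, so that a close() with
        # pending readable data does not trigger an immediate RST
        # (RFC 1122, section 4.2.2.13).
        try:
            while conn.recv(4096):
                pass
        except OSError:
            pass
        conn.close()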
- """ - if self.flags.encryption_enabled(): - ctx = ssl.create_default_context( - ssl.Purpose.CLIENT_AUTH) - ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 - ctx.verify_mode = ssl.CERT_NONE - assert self.flags.keyfile and self.flags.certfile - ctx.load_cert_chain( - certfile=self.flags.certfile, - keyfile=self.flags.keyfile) - conn = ctx.wrap_socket( - conn, - server_side=True, - ) - return conn - - def connection_inactive_for(self) -> float: - return time.time() - self.last_activity - - def flush(self) -> None: - if not self.client.has_buffer(): - return + def handle_data(self, data: memoryview) -> Optional[bool]: + """Handles incoming data from client.""" + if data is None: + logger.debug('Client closed connection, tearing down...') + self.work.closed = True + return True try: - self.selector.register( - self.client.connection, - selectors.EVENT_WRITE) - while self.client.has_buffer(): - ev: List[Tuple[selectors.SelectorKey, int] - ] = self.selector.select(timeout=1) - if len(ev) == 0: - continue - self.client.flush() - except BrokenPipeError: - pass - finally: - self.selector.unregister(self.client.connection) + # We don't parse incoming data any further after 1st HTTP request packet. + # + # Plugins can utilize on_client_data for such cases and + # apply custom logic to handle request data sent after 1st + # valid request. + if self.request.state != httpParserStates.COMPLETE: + if self._parse_first_request(data): + return True + # HttpProtocolHandlerPlugin.on_client_data + # Can raise HttpProtocolException to tear down the connection + elif self.plugin: + self.plugin.on_client_data(data) + except HttpProtocolException as e: + logger.info('HttpProtocolException: %s' % e) + response: Optional[memoryview] = e.response(self.request) + if response: + self.work.queue(response) + return True + return False - def handle_writables(self, writables: List[Union[int, HasFileno]]) -> bool: - if self.client.has_buffer() and self.client.connection in writables: - logger.debug('Client is ready for writes, flushing buffer') + async def handle_writables(self, writables: Writables) -> bool: + if self.work.connection.fileno() in writables and self.work.has_buffer(): + logger.debug('Client is write ready, flushing...') self.last_activity = time.time() - # TODO(abhinavsingh): This hook could just reside within server recv block # instead of invoking when flushed to client. 
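A sketch of what an on_response_chunk hook can do with the buffers it receives at this point; the method body is shown as a free function for brevity, and the rewrite rule is purely illustrative:

    from typing import List

    def on_response_chunk(chunk: List[memoryview]) -> List[memoryview]:
        # Rewrite the first pending buffer before it is flushed to the client.
        if chunk:
            head = chunk[0].tobytes().replace(b'http://', b'https://')
            chunk[0] = memoryview(head)
        return chunk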
+ # # Invoke plugin.on_response_chunk - chunk = self.client.buffer - for plugin in self.plugins.values(): - chunk = plugin.on_response_chunk(chunk) - if chunk is None: - break - + chunk = self.work.buffer + if self.plugin: + chunk = self.plugin.on_response_chunk(chunk) try: - self.client.flush() + # Call super() for client flush + teardown = await super().handle_writables(writables) + if teardown: + return True except BrokenPipeError: - logger.error( - 'BrokenPipeError when flushing buffer for client') + logger.warning( # pragma: no cover + 'BrokenPipeError when flushing buffer for client', + ) return True except OSError: - logger.error('OSError when flushing buffer to client') + logger.warning( # pragma: no cover + 'OSError when flushing buffer to client', + ) return True return False - def handle_readables(self, readables: List[Union[int, HasFileno]]) -> bool: - if self.client.connection in readables: - logger.debug('Client is ready for reads, reading') + async def handle_readables(self, readables: Readables) -> bool: + if self.work.connection.fileno() in readables: + logger.debug('Client is read ready, receiving...') self.last_activity = time.time() try: - client_data = self.client.recv(self.flags.client_recvbuf_size) + teardown = await super().handle_readables(readables) + if teardown: + return teardown except ssl.SSLWantReadError: # Try again later logger.warning( - 'SSLWantReadError encountered while reading from client, will retry ...') + 'SSLWantReadError encountered while reading from client, will retry ...', + ) return False except socket.error as e: if e.errno == errno.ECONNRESET: - logger.warning('%r' % e) + # Most requests for mobile devices will end up + # with client closed connection. Using `debug` + # here to avoid flooding the logs. + logger.debug('%r' % e) else: - logger.exception( - 'Exception while receiving from %s connection %r with reason %r' % - (self.client.tag, self.client.connection, e)) + logger.warning( + 'Exception when receiving from %s connection#%d with reason %r' % + (self.work.tag, self.work.connection.fileno(), e), + exc_info=True, + ) return True + return False - if client_data is None: - logger.debug('Client closed connection, tearing down...') - self.client.closed = True - return True + ## + # Internal methods + ## - try: - # HttpProtocolHandlerPlugin.on_client_data - # Can raise HttpProtocolException to teardown the connection - plugin_index = 0 - plugins = list(self.plugins.values()) - while plugin_index < len(plugins) and client_data: - client_data = plugins[plugin_index].on_client_data( - client_data) - if client_data is None: - break - plugin_index += 1 + def _initialize_plugin( + self, + klass: Type['HttpProtocolHandlerPlugin'], + ) -> HttpProtocolHandlerPlugin: + """Initializes passed HTTP protocol handler plugin class.""" + return klass( + self.uid, + self.flags, + self.work, + self.request, + self.event_queue, + self.upstream_conn_pool, + ) + + def _discover_plugin_klass(self, protocol: int) -> Optional[Type['HttpProtocolHandlerPlugin']]: + """Discovers and return matching HTTP handler plugin matching protocol.""" + if b'HttpProtocolHandlerPlugin' in self.flags.plugins: + for klass in self.flags.plugins[b'HttpProtocolHandlerPlugin']: + k: Type['HttpProtocolHandlerPlugin'] = klass + if protocol in k.protocols(): + return k + return None - # Don't parse request any further after 1st request has completed. - # This specially does happen for pipeline requests. 
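Note that readiness is now tracked by raw file descriptor (`fileno()`) rather than by socket object. A self-contained sketch of the same fd-keyed selector pattern:

    import socket
    import selectors

    sel = selectors.DefaultSelector()
    a, b = socket.socketpair()
    # Register the integer fd, mirroring how events are keyed above.
    sel.register(a.fileno(), selectors.EVENT_READ)
    b.send(b'ping')
    for key, mask in sel.select(timeout=1):
        assert key.fd == a.fileno() and mask & selectors.EVENT_READ
    sel.unregister(a.fileno())
    a.close()
    b.close()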
- # Plugins can utilize on_client_data for such cases and - # apply custom logic to handle request data sent after 1st - # valid request. - if client_data and self.request.state != httpParserStates.COMPLETE: - # Parse http request - # TODO(abhinavsingh): Remove .tobytes after parser is - # memoryview compliant - self.request.parse(client_data.tobytes()) - if self.request.state == httpParserStates.COMPLETE: - # Invoke plugin.on_request_complete - for plugin in self.plugins.values(): - upgraded_sock = plugin.on_request_complete() - if isinstance(upgraded_sock, ssl.SSLSocket): - logger.debug( - 'Updated client conn to %s', upgraded_sock) - self.client._conn = upgraded_sock - for plugin_ in self.plugins.values(): - if plugin_ != plugin: - plugin_.client._conn = upgraded_sock - elif isinstance(upgraded_sock, bool) and upgraded_sock is True: - return True - except HttpProtocolException as e: - logger.debug( - 'HttpProtocolException type raised') - response: Optional[memoryview] = e.response(self.request) - if response: - self.client.queue(response) - return True + def _parse_first_request(self, data: memoryview) -> bool: + # Parse http request + try: + self.request.parse(data) + except HttpProtocolException as e: # noqa: WPS329 + self.work.queue(BAD_REQUEST_RESPONSE_PKT) + raise e + except Exception as e: + self.work.queue(BAD_REQUEST_RESPONSE_PKT) + raise HttpProtocolException( + 'Error when parsing request: %r' % data.tobytes(), + ) from e + if not self.request.is_complete: + return False + # Bail out if http protocol is unknown + if self.request.http_handler_protocol == httpProtocols.UNKNOWN: + self.work.queue(BAD_REQUEST_RESPONSE_PKT) + return True + # Discover which HTTP handler plugin is capable of + # handling the current incoming request + klass = self._discover_plugin_klass( + self.request.http_handler_protocol, + ) + if klass is None: + # No matching protocol class found. + # Return bad request response and + # close the connection. + self.work.queue(BAD_REQUEST_RESPONSE_PKT) + return True + assert klass is not None + self.plugin = self._initialize_plugin(klass) + # Invoke plugin.on_request_complete + output = self.plugin.on_request_complete() + if isinstance(output, bool): + return output + assert isinstance(output, ssl.SSLSocket) + logger.debug( + 'Updated client conn to %s', output, + ) + self.work._conn = output return False - @contextlib.contextmanager - def selected_events(self) -> \ - Generator[Tuple[List[Union[int, HasFileno]], - List[Union[int, HasFileno]]], - None, None]: - events = self.get_events() - for fd in events: - self.selector.register(fd, events[fd]) - ev = self.selector.select(timeout=1) - readables = [] - writables = [] - for key, mask in ev: - if mask & selectors.EVENT_READ: - readables.append(key.fileobj) - if mask & selectors.EVENT_WRITE: - writables.append(key.fileobj) - yield (readables, writables) - for fd in events.keys(): - self.selector.unregister(fd) + def _connection_inactive_for(self) -> float: + return time.time() - self.last_activity - def run_once(self) -> bool: - with self.selected_events() as (readables, writables): - teardown = self.handle_events(readables, writables) - if teardown: - return True - return False + ## + # run() and _run_once() are here to maintain backward compatibility + # with threaded mode. These methods are only called when running + # in threaded mode. + ## def run(self) -> None: + """run() method is not used when in --threadless mode. + + This is here just to maintain backward compatibility with threaded mode. 
+ """ + loop = asyncio.new_event_loop() try: self.initialize() while True: - # Teardown if client buffer is empty and connection is inactive + # Tear down if client buffer is empty and connection is inactive if self.is_inactive(): logger.debug( 'Client buffer is empty and maximum inactivity has reached ' - 'between client and server connection, tearing down...') + 'between client and server connection, tearing down...', + ) break - teardown = self.run_once() - if teardown: + if loop.run_until_complete(self._run_once()): break except KeyboardInterrupt: # pragma: no cover pass @@ -412,6 +342,60 @@ def run(self) -> None: except Exception as e: logger.exception( 'Exception while handling connection %r' % - self.client.connection, exc_info=e) + self.work.connection, exc_info=e, + ) finally: self.shutdown() + if self.selector: + self.selector.close() + loop.close() + + async def _run_once(self) -> bool: + events, readables, writables = await self._selected_events() + try: + return await self.handle_events(readables, writables) + finally: + assert self.selector + # TODO: Like Threadless we should not unregister + # work fds repeatedly. + for fd in events: + self.selector.unregister(fd) + + # FIXME: Returning events is only necessary because we cannot use async context manager + # for < Python 3.8. As a reason, this method is no longer a context manager and caller + # is responsible for unregistering the descriptors. + async def _selected_events(self) -> Tuple[SelectableEvents, Readables, Writables]: + assert self.selector + events = await self.get_events() + for fd in events: + self.selector.register(fd, events[fd]) + ev = self.selector.select(timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT) + readables = [] + writables = [] + for key, mask in ev: + if mask & selectors.EVENT_READ: + readables.append(key.fd) + if mask & selectors.EVENT_WRITE: + writables.append(key.fd) + return (events, readables, writables) + + def _flush(self) -> None: + assert self.selector + logger.debug('Flushing pending data') + try: + self.selector.register( + self.work.connection, + selectors.EVENT_WRITE, + ) + while self.work.has_buffer(): + logging.debug('Waiting for client read ready') + ev: List[ + Tuple[selectors.SelectorKey, int] + ] = self.selector.select(timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT) + if len(ev) == 0: + continue + self.work.flush(self.flags.max_sendbuf_size) + except BrokenPipeError: + pass + finally: + self.selector.unregister(self.work.connection) diff --git a/proxy/http/headers.py b/proxy/http/headers.py new file mode 100644 index 000000000..d042067c0 --- /dev/null +++ b/proxy/http/headers.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + .. spelling:: + + http + iterable +""" +from typing import NamedTuple + + +# Ref: https://www.iana.org/assignments/http-methods/http-methods.xhtml +HttpHeaders = NamedTuple( + 'HttpHeaders', [ + ('PROXY_AUTHORIZATION', bytes), + ('PROXY_CONNECTION', bytes), + ], +) + +httpHeaders = HttpHeaders( + b'proxy-authorization', + b'proxy-connection', +) diff --git a/proxy/http/methods.py b/proxy/http/methods.py index 63b8cd1e9..30c9b71a7 100644 --- a/proxy/http/methods.py +++ b/proxy/http/methods.py @@ -7,29 +7,100 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. 
:license: BSD, see LICENSE for more details. + + .. spelling:: + + http + iterable """ from typing import NamedTuple -HttpMethods = NamedTuple('HttpMethods', [ - ('GET', bytes), - ('HEAD', bytes), - ('POST', bytes), - ('PUT', bytes), - ('DELETE', bytes), - ('CONNECT', bytes), - ('OPTIONS', bytes), - ('TRACE', bytes), - ('PATCH', bytes), -]) +# Ref: https://www.iana.org/assignments/http-methods/http-methods.xhtml +HttpMethods = NamedTuple( + 'HttpMethods', [ + ('ACL', bytes), + ('BASELINE_CONTROL', bytes), + ('BIND', bytes), + ('CHECKIN', bytes), + ('CHECKOUT', bytes), + ('CONNECT', bytes), + ('COPY', bytes), + ('DELETE', bytes), + ('GET', bytes), + ('HEAD', bytes), + ('LABEL', bytes), + ('LINK', bytes), + ('LOCK', bytes), + ('MERGE', bytes), + ('MKACTIVITY', bytes), + ('MKCALENDAR', bytes), + ('MKCOL', bytes), + ('MKREDIRECTREF', bytes), + ('MKWORKSPACE', bytes), + ('MOVE', bytes), + ('OPTIONS', bytes), + ('ORDERPATCH', bytes), + ('PATCH', bytes), + ('POST', bytes), + ('PRI', bytes), + ('PROPFIND', bytes), + ('PROPPATCH', bytes), + ('PUT', bytes), + ('REBIND', bytes), + ('REPORT', bytes), + ('SEARCH', bytes), + ('TRACE', bytes), + ('UNBIND', bytes), + ('UNCHECKOUT', bytes), + ('UNLINK', bytes), + ('UNLOCK', bytes), + ('UPDATE', bytes), + ('UPDATEREDIRECTREF', bytes), + ('VERSION_CONTROL', bytes), + ('STAR', bytes), + ], +) + httpMethods = HttpMethods( + b'ACL', + b'BASELINE-CONTROL', + b'BIND', + b'CHECKIN', + b'CHECKOUT', + b'CONNECT', + b'COPY', + b'DELETE', b'GET', b'HEAD', + b'LABEL', + b'LINK', + b'LOCK', + b'MERGE', + b'MKACTIVITY', + b'MKCALENDAR', + b'MKCOL', + b'MKREDIRECTREF', + b'MKWORKSPACE', + b'MOVE', + b'OPTIONS', + b'ORDERPATCH', + b'PATCH', b'POST', + b'PRI', + b'PROPFIND', + b'PROPPATCH', b'PUT', - b'DELETE', - b'CONNECT', - b'OPTIONS', + b'REBIND', + b'REPORT', + b'SEARCH', b'TRACE', - b'PATCH', + b'UNBIND', + b'UNCHECKOUT', + b'UNLINK', + b'UNLOCK', + b'UPDATE', + b'UPDATEREDIRECTREF', + b'VERSION-CONTROL', + b'*', ) diff --git a/proxy/http/parser.py b/proxy/http/parser.py deleted file mode 100644 index 63c62b8b9..000000000 --- a/proxy/http/parser.py +++ /dev/null @@ -1,265 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. - :license: BSD, see LICENSE for more details. 
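The expanded httpMethods tuple above now covers the WebDAV and versioning verbs from the IANA registry. A small usage sketch against the new constants (the frozenset helper is an assumption for illustration, not part of the patch):

    from proxy.http import httpMethods

    SAFE_METHODS = frozenset({
        httpMethods.GET,
        httpMethods.HEAD,
        httpMethods.OPTIONS,
        httpMethods.TRACE,
    })

    def is_safe(method: bytes) -> bool:
        # RFC 7231, section 4.2.1: safe methods are read-only.
        return method.upper() in SAFE_METHODS

    assert is_safe(b'get') and not is_safe(b'DELETE')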
-""" -from urllib import parse as urlparse -from typing import TypeVar, NamedTuple, Optional, Dict, Type, Tuple, List - -from .methods import httpMethods -from .chunk_parser import ChunkParser, chunkParserStates - -from ..common.constants import DEFAULT_DISABLE_HEADERS, COLON, CRLF, WHITESPACE, HTTP_1_1, DEFAULT_HTTP_PORT -from ..common.utils import build_http_request, find_http_line, text_ - - -HttpParserStates = NamedTuple('HttpParserStates', [ - ('INITIALIZED', int), - ('LINE_RCVD', int), - ('RCVING_HEADERS', int), - ('HEADERS_COMPLETE', int), - ('RCVING_BODY', int), - ('COMPLETE', int), -]) -httpParserStates = HttpParserStates(1, 2, 3, 4, 5, 6) - -HttpParserTypes = NamedTuple('HttpParserTypes', [ - ('REQUEST_PARSER', int), - ('RESPONSE_PARSER', int), -]) -httpParserTypes = HttpParserTypes(1, 2) - - -T = TypeVar('T', bound='HttpParser') - - -class HttpParser: - """HTTP request/response parser.""" - - def __init__(self, parser_type: int) -> None: - self.type: int = parser_type - self.state: int = httpParserStates.INITIALIZED - - # Total size of raw bytes passed for parsing - self.total_size: int = 0 - - # Buffer to hold unprocessed bytes - self.buffer: bytes = b'' - - self.headers: Dict[bytes, Tuple[bytes, bytes]] = dict() - self.body: Optional[bytes] = None - - self.method: Optional[bytes] = None - self.url: Optional[urlparse.SplitResultBytes] = None - self.code: Optional[bytes] = None - self.reason: Optional[bytes] = None - self.version: Optional[bytes] = None - - self.chunk_parser: Optional[ChunkParser] = None - - # This cleans up developer APIs as Python urlparse.urlsplit behaves differently - # for incoming proxy request and incoming web request. Web request is the one - # which is broken. - self.host: Optional[bytes] = None - self.port: Optional[int] = None - self.path: Optional[bytes] = None - - @classmethod - def request(cls: Type[T], raw: bytes) -> T: - parser = cls(httpParserTypes.REQUEST_PARSER) - parser.parse(raw) - return parser - - @classmethod - def response(cls: Type[T], raw: bytes) -> T: - parser = cls(httpParserTypes.RESPONSE_PARSER) - parser.parse(raw) - return parser - - def header(self, key: bytes) -> bytes: - if key.lower() not in self.headers: - raise KeyError('%s not found in headers', text_(key)) - return self.headers[key.lower()][1] - - def has_header(self, key: bytes) -> bool: - return key.lower() in self.headers - - def add_header(self, key: bytes, value: bytes) -> None: - self.headers[key.lower()] = (key, value) - - def add_headers(self, headers: List[Tuple[bytes, bytes]]) -> None: - for (key, value) in headers: - self.add_header(key, value) - - def del_header(self, header: bytes) -> None: - if header.lower() in self.headers: - del self.headers[header.lower()] - - def del_headers(self, headers: List[bytes]) -> None: - for key in headers: - self.del_header(key.lower()) - - def set_url(self, url: bytes) -> None: - # Work around with urlsplit semantics. - # - # For CONNECT requests, request line contains - # upstream_host:upstream_port which is not complaint - # with urlsplit, which expects a fully qualified url. 
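The urlsplit workaround being deleted here is easy to reproduce in isolation; in the new parser, Url.from_bytes takes over this responsibility. A standard-library-only sketch:

    from urllib.parse import urlsplit

    # CONNECT request targets arrive in authority-form (host:port),
    # which urlsplit cannot parse without a scheme prefix.
    target = b'example.com:8443'
    url = urlsplit(b'https://' + target)
    assert url.hostname == b'example.com' and url.port == 8443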
- if self.method == b'CONNECT': - url = b'https://' + url - self.url = urlparse.urlsplit(url) - self.set_line_attributes() - - def set_line_attributes(self) -> None: - if self.type == httpParserTypes.REQUEST_PARSER: - if self.method == httpMethods.CONNECT and self.url: - self.host = self.url.hostname - self.port = 443 if self.url.port is None else self.url.port - elif self.url: - self.host, self.port = self.url.hostname, self.url.port \ - if self.url.port else DEFAULT_HTTP_PORT - else: - raise KeyError( - 'Invalid request. Method: %r, Url: %r' % - (self.method, self.url)) - self.path = self.build_path() - - def is_chunked_encoded(self) -> bool: - return b'transfer-encoding' in self.headers and \ - self.headers[b'transfer-encoding'][1].lower() == b'chunked' - - def body_expected(self) -> bool: - return (b'content-length' in self.headers and - int(self.header(b'content-length')) > 0) or \ - self.is_chunked_encoded() - - def parse(self, raw: bytes) -> None: - """Parses Http request out of raw bytes. - - Check HttpParser state after parse has successfully returned.""" - self.total_size += len(raw) - raw = self.buffer + raw - self.buffer = b'' - - more = True if len(raw) > 0 else False - while more and self.state != httpParserStates.COMPLETE: - if self.state in ( - httpParserStates.HEADERS_COMPLETE, - httpParserStates.RCVING_BODY): - if b'content-length' in self.headers: - self.state = httpParserStates.RCVING_BODY - if self.body is None: - self.body = b'' - total_size = int(self.header(b'content-length')) - received_size = len(self.body) - self.body += raw[:total_size - received_size] - if self.body and \ - len(self.body) == int(self.header(b'content-length')): - self.state = httpParserStates.COMPLETE - more, raw = len(raw) > 0, raw[total_size - received_size:] - elif self.is_chunked_encoded(): - if not self.chunk_parser: - self.chunk_parser = ChunkParser() - raw = self.chunk_parser.parse(raw) - if self.chunk_parser.state == chunkParserStates.COMPLETE: - self.body = self.chunk_parser.body - self.state = httpParserStates.COMPLETE - more = False - else: - raise NotImplementedError('Parser shouldn\'t have reached here') - else: - more, raw = self.process(raw) - self.buffer = raw - - def process(self, raw: bytes) -> Tuple[bool, bytes]: - """Returns False when no CRLF could be found in received bytes.""" - line, raw = find_http_line(raw) - if line is None: - return False, raw - - if self.state == httpParserStates.INITIALIZED: - self.process_line(line) - self.state = httpParserStates.LINE_RCVD - elif self.state in (httpParserStates.LINE_RCVD, httpParserStates.RCVING_HEADERS): - if self.state == httpParserStates.LINE_RCVD: - # LINE_RCVD state is equivalent to RCVING_HEADERS - self.state = httpParserStates.RCVING_HEADERS - if line.strip() == b'': # Blank line received. - self.state = httpParserStates.HEADERS_COMPLETE - else: - self.process_header(line) - - # When server sends a response line without any header or body e.g. 
- # HTTP/1.1 200 Connection established\r\n\r\n - if self.state == httpParserStates.LINE_RCVD and \ - self.type == httpParserTypes.RESPONSE_PARSER and \ - raw == CRLF: - self.state = httpParserStates.COMPLETE - elif self.state == httpParserStates.HEADERS_COMPLETE and \ - not self.body_expected() and \ - raw == b'': - self.state = httpParserStates.COMPLETE - - return len(raw) > 0, raw - - def process_line(self, raw: bytes) -> None: - line = raw.split(WHITESPACE) - if self.type == httpParserTypes.REQUEST_PARSER: - self.method = line[0].upper() - self.set_url(line[1]) - self.version = line[2] - else: - self.version = line[0] - self.code = line[1] - self.reason = WHITESPACE.join(line[2:]) - - def process_header(self, raw: bytes) -> None: - parts = raw.split(COLON) - key = parts[0].strip() - value = COLON.join(parts[1:]).strip() - self.add_headers([(key, value)]) - - def build_path(self) -> bytes: - if not self.url: - return b'/None' - url = self.url.path - if url == b'': - url = b'/' - if not self.url.query == b'': - url += b'?' + self.url.query - if not self.url.fragment == b'': - url += b'#' + self.url.fragment - return url - - def build(self, disable_headers: Optional[List[bytes]] = None) -> bytes: - assert self.method and self.version and self.path - if disable_headers is None: - disable_headers = DEFAULT_DISABLE_HEADERS - body: Optional[bytes] = ChunkParser.to_chunks(self.body) \ - if self.is_chunked_encoded() and self.body else \ - self.body - return build_http_request( - self.method, self.path, self.version, - headers={} if not self.headers else {self.headers[k][0]: self.headers[k][1] for k in self.headers if - k.lower() not in disable_headers}, - body=body - ) - - def has_upstream_server(self) -> bool: - """Host field SHOULD be None for incoming local WebServer requests.""" - return True if self.host is not None else False - - def is_http_1_1_keep_alive(self) -> bool: - return self.version == HTTP_1_1 and \ - (not self.has_header(b'Connection') or - self.header(b'Connection').lower() == b'keep-alive') - - def is_connection_upgrade(self) -> bool: - return self.version == HTTP_1_1 and \ - self.has_header(b'Connection') and \ - self.has_header(b'Upgrade') diff --git a/proxy/http/parser/__init__.py b/proxy/http/parser/__init__.py new file mode 100644 index 000000000..d19766652 --- /dev/null +++ b/proxy/http/parser/__init__.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + .. spelling:: + + http + Submodules +""" +from .chunk import ChunkParser, chunkParserStates +from .types import httpParserTypes, httpParserStates +from .parser import HttpParser +from .protocol import PROXY_PROTOCOL_V2_SIGNATURE, ProxyProtocol + + +__all__ = [ + 'HttpParser', + 'httpParserTypes', + 'httpParserStates', + 'ChunkParser', + 'chunkParserStates', + 'ProxyProtocol', + 'PROXY_PROTOCOL_V2_SIGNATURE', +] diff --git a/proxy/http/chunk_parser.py b/proxy/http/parser/chunk.py similarity index 81% rename from proxy/http/chunk_parser.py rename to proxy/http/parser/chunk.py index 2b9b72c42..eb3579895 100644 --- a/proxy/http/chunk_parser.py +++ b/proxy/http/parser/chunk.py @@ -8,17 +8,19 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. 
""" -from typing import NamedTuple, Tuple, List, Optional +from typing import List, Tuple, Optional, NamedTuple -from ..common.utils import bytes_, find_http_line -from ..common.constants import CRLF, DEFAULT_BUFFER_SIZE +from ...common.utils import bytes_, find_http_line +from ...common.constants import CRLF, DEFAULT_BUFFER_SIZE -ChunkParserStates = NamedTuple('ChunkParserStates', [ - ('WAITING_FOR_SIZE', int), - ('WAITING_FOR_DATA', int), - ('COMPLETE', int), -]) +ChunkParserStates = NamedTuple( + 'ChunkParserStates', [ + ('WAITING_FOR_SIZE', int), + ('WAITING_FOR_DATA', int), + ('COMPLETE', int), + ], +) chunkParserStates = ChunkParserStates(1, 2, 3) @@ -32,13 +34,13 @@ def __init__(self) -> None: # Expected size of next following chunk self.size: Optional[int] = None - def parse(self, raw: bytes) -> bytes: - more = True if len(raw) > 0 else False + def parse(self, raw: memoryview) -> memoryview: + more = len(raw) > 0 while more and self.state != chunkParserStates.COMPLETE: - more, raw = self.process(raw) + more, raw = self.process(raw.tobytes()) return raw - def process(self, raw: bytes) -> Tuple[bool, bytes]: + def process(self, raw: bytes) -> Tuple[bool, memoryview]: if self.state == chunkParserStates.WAITING_FOR_SIZE: # Consume prior chunk in buffer # in case chunk size without CRLF was received @@ -67,7 +69,7 @@ def process(self, raw: bytes) -> Tuple[bool, bytes]: self.state = chunkParserStates.WAITING_FOR_SIZE self.chunk = b'' self.size = None - return len(raw) > 0, raw + return len(raw) > 0, memoryview(raw) @staticmethod def to_chunks(raw: bytes, chunk_size: int = DEFAULT_BUFFER_SIZE) -> bytes: diff --git a/proxy/http/parser/parser.py b/proxy/http/parser/parser.py new file mode 100644 index 000000000..91a0cbbc9 --- /dev/null +++ b/proxy/http/parser/parser.py @@ -0,0 +1,460 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + .. spelling:: + + http +""" +import gzip +from typing import Dict, List, Type, Tuple, TypeVar, Optional + +from ..url import Url +from .chunk import ChunkParser, chunkParserStates +from .types import httpParserTypes, httpParserStates +from ..methods import httpMethods +from .protocol import ProxyProtocol +from ..exception import HttpProtocolException +from ..protocols import httpProtocols +from ...common.flag import flags +from ...common.utils import ( + text_, bytes_, build_http_request, build_http_response, +) +from ...common.constants import ( + CRLF, COLON, SLASH, HTTP_1_0, HTTP_1_1, WHITESPACE, DEFAULT_HTTP_PORT, + DEFAULT_DISABLE_HEADERS, DEFAULT_ENABLE_PROXY_PROTOCOL, +) + + +flags.add_argument( + '--enable-proxy-protocol', + action='store_true', + default=DEFAULT_ENABLE_PROXY_PROTOCOL, + help='Default: ' + str(DEFAULT_ENABLE_PROXY_PROTOCOL) + '. ' + + 'If used, will enable proxy protocol. ' + + 'Only version 1 is currently supported.', +) + + +T = TypeVar('T', bound='HttpParser') + + +class HttpParser: + """HTTP request/response parser. + + TODO: Make me zero-copy by using :class:`memoryview`. + Currently due to chunk/buffer handling we + are not able to utilize :class:`memoryview` + efficiently. + + For this to happen we must store ``buffer`` + as ``List[memoryview]`` instead of raw bytes and + update parser to work accordingly. 
+ """ + + def __init__( + self, parser_type: int, + enable_proxy_protocol: int = DEFAULT_ENABLE_PROXY_PROTOCOL, + ) -> None: + self.state: int = httpParserStates.INITIALIZED + self.type: int = parser_type + self.protocol: Optional[ProxyProtocol] = None + if enable_proxy_protocol: + assert self.type == httpParserTypes.REQUEST_PARSER + self.protocol = ProxyProtocol() + # Request attributes + self.host: Optional[bytes] = None + self.port: Optional[int] = None + self.path: Optional[bytes] = None + self.method: Optional[bytes] = None + # Response attributes + self.code: Optional[bytes] = None + self.reason: Optional[bytes] = None + self.version: Optional[bytes] = None + # Total size of raw bytes passed for parsing + self.total_size: int = 0 + # Buffer to hold unprocessed bytes + self.buffer: Optional[memoryview] = None + # Internal headers data structure: + # - Keys are lower case header names. + # - Values are 2-tuple containing original + # header and it's value as received. + self.headers: Optional[Dict[bytes, Tuple[bytes, bytes]]] = None + self.body: Optional[bytes] = None + self.chunk: Optional[ChunkParser] = None + # Internal request line as a url structure + self._url: Optional[Url] = None + # Deduced states from the packet + self._is_chunked_encoded: bool = False + self._content_expected: bool = False + self._is_https_tunnel: bool = False + + @classmethod + def request( + cls: Type[T], + raw: bytes, + enable_proxy_protocol: int = DEFAULT_ENABLE_PROXY_PROTOCOL, + ) -> T: + parser = cls( + httpParserTypes.REQUEST_PARSER, + enable_proxy_protocol=enable_proxy_protocol, + ) + parser.parse(memoryview(raw)) + return parser + + @classmethod + def response(cls: Type[T], raw: bytes) -> T: + parser = cls(httpParserTypes.RESPONSE_PARSER) + parser.parse(memoryview(raw)) + return parser + + def header(self, key: bytes) -> bytes: + """Convenient method to return original header value from internal data structure.""" + if self.headers is None or key.lower() not in self.headers: + raise KeyError('%s not found in headers' % text_(key)) + return self.headers[key.lower()][1] + + def has_header(self, key: bytes) -> bool: + """Returns true if header key was found in payload.""" + if self.headers is None: + return False + return key.lower() in self.headers + + def add_header(self, key: bytes, value: bytes) -> bytes: + """Add/Update a header to internal data structure. + + Returns key with which passed (key, value) tuple is available.""" + if self.headers is None: + self.headers = {} + k = key.lower() + # k = key + self.headers[k] = (key, value) + return k + + def add_headers(self, headers: List[Tuple[bytes, bytes]]) -> None: + """Add/Update multiple headers to internal data structure""" + for (key, value) in headers: + self.add_header(key, value) + + def del_header(self, header: bytes) -> None: + """Delete a header from internal data structure.""" + if self.headers and header.lower() in self.headers: + del self.headers[header.lower()] + + def del_headers(self, headers: List[bytes]) -> None: + """Delete headers from internal data structure.""" + for key in headers: + self.del_header(key.lower()) + + def set_url(self, url: bytes, allowed_url_schemes: Optional[List[bytes]] = None) -> None: + """Given a request line, parses it and sets line attributes a.k.a. 
host, port, path.""" + self._url = Url.from_bytes( + url, allowed_url_schemes=allowed_url_schemes, + ) + self._set_line_attributes() + + def update_body(self, body: bytes, content_type: bytes) -> None: + """This method must be used to update body after HTTP packet has been parsed. + + Along with updating the body, this method also respects original + request content encoding, transfer encoding settings.""" + # If outgoing request encoding is gzip + # also compress the body + if self.has_header(b'content-encoding'): + if self.header(b'content-encoding') == b'gzip': + body = gzip.compress(body) + else: + # We only work with gzip, for any other encoding + # type, remove the original header + self.del_header(b'content-encoding') + # If the request is of type chunked encoding + # add post data as chunk + if self.is_chunked_encoded: + body = ChunkParser.to_chunks(body) + self.del_header(b'content-length') + else: + self.add_header( + b'Content-Length', + bytes_(len(body)), + ) + self.body = body + self.add_header(b'Content-Type', content_type) + + @property + def http_handler_protocol(self) -> int: + """Returns `HttpProtocols` that this request belongs to.""" + if self.version in (HTTP_1_1, HTTP_1_0) and self._url is not None: + if self.host is not None: + return httpProtocols.HTTP_PROXY + if self._url.hostname is None: + return httpProtocols.WEB_SERVER + return httpProtocols.UNKNOWN + + @property + def is_complete(self) -> bool: + return self.state == httpParserStates.COMPLETE + + @property + def is_http_1_1_keep_alive(self) -> bool: + """Returns true for HTTP/1.1 keep-alive connections.""" + return self.version == HTTP_1_1 and \ + ( + not self.has_header(b'Connection') or + self.header(b'Connection').lower() == b'keep-alive' + ) + + @property + def is_connection_upgrade(self) -> bool: + """Returns true for websocket upgrade requests.""" + return self.version == HTTP_1_1 and \ + self.has_header(b'Connection') and \ + self.has_header(b'Upgrade') + + @property + def is_https_tunnel(self) -> bool: + """Returns true for HTTPS CONNECT tunnel request.""" + return self._is_https_tunnel + + @property + def is_chunked_encoded(self) -> bool: + """Returns true if transfer-encoding chunked is used.""" + return self._is_chunked_encoded + + @property + def content_expected(self) -> bool: + """Returns true if content-length is present and not 0.""" + return self._content_expected + + @property + def body_expected(self) -> bool: + """Returns true if content or chunked response is expected.""" + return self._content_expected or self._is_chunked_encoded + + def parse( + self, + raw: memoryview, + allowed_url_schemes: Optional[List[bytes]] = None, + ) -> None: + """Parses HTTP request out of raw bytes. + + Check for `HttpParser.state` after `parse` has successfully returned.""" + size = len(raw) + self.total_size += size + if self.buffer: + # TODO(abhinavsingh): Instead of tobytes our parser + # must be capable of working with arrays of memoryview + raw = memoryview(self.buffer.tobytes() + raw.tobytes()) + self.buffer, more = None, size > 0 + while more and self.state != httpParserStates.COMPLETE: + # gte with HEADERS_COMPLETE also encapsulated RCVING_BODY state + if self.state >= httpParserStates.HEADERS_COMPLETE: + more, raw = self._process_body(raw) + elif self.state == httpParserStates.INITIALIZED: + more, raw = self._process_line( + raw, + allowed_url_schemes=allowed_url_schemes, + ) + else: + more, raw = self._process_headers(raw) + # When server sends a response line without any header or body e.g. 
+ # HTTP/1.1 200 Connection established\r\n\r\n + if self.type == httpParserTypes.RESPONSE_PARSER and \ + self.state == httpParserStates.LINE_RCVD and \ + raw == CRLF: + self.state = httpParserStates.COMPLETE + # Mark request as complete if headers received and no incoming + # body indication received. + elif self.state == httpParserStates.HEADERS_COMPLETE and \ + not (self._content_expected or self._is_chunked_encoded) and \ + raw == b'': + self.state = httpParserStates.COMPLETE + self.buffer = None if raw == b'' else raw + + def build(self, disable_headers: Optional[List[bytes]] = None, for_proxy: bool = False) -> bytes: + """Rebuild the request object.""" + assert self.method and self.version and self.type == httpParserTypes.REQUEST_PARSER + if disable_headers is None: + disable_headers = DEFAULT_DISABLE_HEADERS + body: Optional[bytes] = self._get_body_or_chunks() + path = self.path or b'/' + if for_proxy: + assert self.host and self.port and self._url + path = ( + b'http' if not self._url.scheme else self._url.scheme + + COLON + SLASH + SLASH + + self.host + + COLON + + str(self.port).encode() + + path + ) if not self._is_https_tunnel else (self.host + COLON + str(self.port).encode()) + return build_http_request( + self.method, path, self.version, + headers={} if not self.headers else { + self.headers[k][0]: self.headers[k][1] for k in self.headers if + k.lower() not in disable_headers + }, + body=body, + no_ua=True, + ) + + def build_response(self) -> bytes: + """Rebuild the response object.""" + assert self.code and self.version and self.type == httpParserTypes.RESPONSE_PARSER + return build_http_response( + status_code=int(self.code), + protocol_version=self.version, + reason=self.reason, + headers={} if not self.headers else { + self.headers[k][0]: self.headers[k][1] for k in self.headers + }, + body=self._get_body_or_chunks(), + ) + + def _process_body(self, raw: memoryview) -> Tuple[bool, memoryview]: + # Ref: http://www.ietf.org/rfc/rfc2616.txt + # 3.If a Content-Length header field (section 14.13) is present, its + # decimal value in OCTETs represents both the entity-length and the + # transfer-length. The Content-Length header field MUST NOT be sent + # if these two lengths are different (i.e., if a Transfer-Encoding + # header field is present). If a message is received with both a + # Transfer-Encoding header field and a Content-Length header field, + # the latter MUST be ignored. + # + # TL;DR -- Give transfer-encoding header preference over content-length. + if self._is_chunked_encoded: + if not self.chunk: + self.chunk = ChunkParser() + raw = self.chunk.parse(raw) + if self.chunk.state == chunkParserStates.COMPLETE: + self.body = self.chunk.body + self.state = httpParserStates.COMPLETE + more = False + return more, raw + if self._content_expected: + self.state = httpParserStates.RCVING_BODY + if self.body is None: + self.body = b'' + total_size = int(self.header(b'content-length')) + received_size = len(self.body) + self.body += raw[:total_size - received_size] + if self.body and \ + len(self.body) == int(self.header(b'content-length')): + self.state = httpParserStates.COMPLETE + return len(raw) > 0, raw[total_size - received_size:] + # Received a packet without content-length header + # and no transfer-encoding specified. + # + # This can happen for both HTTP/1.0 and HTTP/1.1 scenarios. + # Currently, we consume the remaining buffer as body. 
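For context on the transfer-encoding branch above, the chunked wire format that ChunkParser consumes, and ChunkParser.to_chunks produces, looks like this; the helper below is a simplified re-implementation for illustration, not the shipped one:

    def to_chunks(raw: bytes, chunk_size: int = 16) -> bytes:
        # Hex size, CRLF, payload, CRLF ... terminated by a zero-size chunk.
        out = b''
        for i in range(0, len(raw), chunk_size):
            part = raw[i:i + chunk_size]
            out += b'%x\r\n%s\r\n' % (len(part), part)
        return out + b'0\r\n\r\n'

    assert to_chunks(b'hello') == b'5\r\nhello\r\n0\r\n\r\n'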
+ # + # Ref https://github.com/abhinavsingh/proxy.py/issues/398 + # + # See TestHttpParser.test_issue_398 scenario + self.state = httpParserStates.RCVING_BODY + self.body = bytes(raw) + return False, memoryview(b'') + + def _process_headers(self, raw: memoryview) -> Tuple[bool, memoryview]: + """Returns False when no CRLF could be found in received bytes. + + TODO: We should not return until parser reaches headers complete + state or when there is no more data left to parse. + + TODO: For protection against Slowloris attack, we must parse the + request line and headers only after receiving end of header marker. + This will also help make the parser even more stateless. + """ + while True: + parts = raw.tobytes().split(CRLF, 1) + if len(parts) == 1: + return False, raw + line, raw = parts[0], memoryview(parts[1]) + if self.state in (httpParserStates.LINE_RCVD, httpParserStates.RCVING_HEADERS): + if line == b'' or line.strip() == b'': # Blank line received. + self.state = httpParserStates.HEADERS_COMPLETE + else: + self.state = httpParserStates.RCVING_HEADERS + self._process_header(line) + # If raw length is now zero, bail out + # If we have received all headers, bail out + if raw == b'' or self.state == httpParserStates.HEADERS_COMPLETE: + break + return len(raw) > 0, raw + + def _process_line( + self, + raw: memoryview, + allowed_url_schemes: Optional[List[bytes]] = None, + ) -> Tuple[bool, memoryview]: + while True: + parts = raw.tobytes().split(CRLF, 1) + if len(parts) == 1: + return False, raw + line, raw = parts[0], memoryview(parts[1]) + if self.type == httpParserTypes.REQUEST_PARSER: + if self.protocol is not None and self.protocol.version is None: + # We expect to receive entire proxy protocol v1 line + # in one network read and don't expect partial packets + self.protocol.parse(line) + continue + # Ref: https://datatracker.ietf.org/doc/html/rfc2616#section-5.1 + parts = line.split(WHITESPACE, 2) + if len(parts) == 3: + self.method = parts[0] + if self.method == httpMethods.CONNECT: + self._is_https_tunnel = True + self.set_url( + parts[1], allowed_url_schemes=allowed_url_schemes, + ) + self.version = parts[2] + self.state = httpParserStates.LINE_RCVD + break + # To avoid a possible attack vector, we raise exception + # if parser receives an invalid request line. 
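The `self.protocol.parse(line)` branch above consumes a HAProxy proxy protocol v1 header ahead of the regular request line when `--enable-proxy-protocol` is in effect. A quick sketch of what the `ProxyProtocol` parser (added below in proxy/http/parser/protocol.py) extracts from such a line; the import path is assumed:

    from proxy.http.parser.protocol import ProxyProtocol

    pp = ProxyProtocol()
    pp.parse(b'PROXY TCP4 192.168.0.1 192.168.0.11 56324 443')
    assert pp.version == 1 and pp.family == b'TCP4'
    assert pp.source == (b'192.168.0.1', 56324)
    assert pp.destination == (b'192.168.0.11', 443)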
+ raise HttpProtocolException('Invalid request line %r' % raw) + parts = line.split(WHITESPACE, 2) + self.version = parts[0] + self.code = parts[1] + # Our own WebServerPlugin example currently doesn't send any reason + if len(parts) == 3: + self.reason = parts[2] + self.state = httpParserStates.LINE_RCVD + break + return len(raw) > 0, raw + + def _process_header(self, raw: bytes) -> None: + parts = raw.split(COLON, 1) + key, value = ( + parts[0].strip(), + b'' if len(parts) == 1 else parts[1].strip(), + ) + k = self.add_header(key, value) + # b'content-length' in self.headers and int(self.header(b'content-length')) > 0 + if k == b'content-length' and int(value) > 0: + self._content_expected = True + # return b'transfer-encoding' in self.headers and \ + # self.headers[b'transfer-encoding'][1].lower() == b'chunked' + elif k == b'transfer-encoding' and value.lower() == b'chunked': + self._is_chunked_encoded = True + + def _get_body_or_chunks(self) -> Optional[bytes]: + return ChunkParser.to_chunks(self.body) \ + if self.body and self._is_chunked_encoded else \ + self.body + + def _set_line_attributes(self) -> None: + if self.type == httpParserTypes.REQUEST_PARSER: + assert self._url + if self._is_https_tunnel: + self.host = self._url.hostname + self.port = 443 if self._url.port is None else self._url.port + else: + self.host, self.port = self._url.hostname, self._url.port \ + if self._url.port else DEFAULT_HTTP_PORT + self.path = self._url.remainder diff --git a/proxy/http/parser/protocol.py b/proxy/http/parser/protocol.py new file mode 100644 index 000000000..2fc19df48 --- /dev/null +++ b/proxy/http/parser/protocol.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. 
+""" +from typing import Tuple, Optional + +from ..exception import HttpProtocolException +from ...common.constants import WHITESPACE + + +PROXY_PROTOCOL_V2_SIGNATURE = b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A' + + +class ProxyProtocol: + """Reference https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt""" + + def __init__(self) -> None: + self.version: Optional[int] = None + self.family: Optional[bytes] = None + self.source: Optional[Tuple[bytes, int]] = None + self.destination: Optional[Tuple[bytes, int]] = None + + def parse(self, raw: bytes) -> None: + if raw.startswith(b'PROXY'): + self.version = 1 + # Per spec, v1 line cannot exceed this limit + assert len(raw) <= 57 + line = raw.split(WHITESPACE) + assert line[0] == b'PROXY' and line[1] in ( + b'TCP4', b'TCP6', b'UNKNOWN', + ) + self.family = line[1] + if len(line) == 6: + self.source = (line[2], int(line[4])) + self.destination = (line[3], int(line[5])) + else: + assert self.family == b'UNKNOWN' + elif raw.startswith(PROXY_PROTOCOL_V2_SIGNATURE): + self.version = 2 + raise NotImplementedError() + else: + raise HttpProtocolException( + 'Neither a v1 or v2 proxy protocol packet', + ) diff --git a/proxy/http/parser/types.py b/proxy/http/parser/types.py new file mode 100644 index 000000000..c5e019c6e --- /dev/null +++ b/proxy/http/parser/types.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + .. spelling:: + + http + iterable +""" +from typing import NamedTuple + + +HttpParserStates = NamedTuple( + 'HttpParserStates', [ + ('INITIALIZED', int), + ('LINE_RCVD', int), + ('RCVING_HEADERS', int), + ('HEADERS_COMPLETE', int), + ('RCVING_BODY', int), + ('COMPLETE', int), + ], +) +httpParserStates = HttpParserStates(1, 2, 3, 4, 5, 6) + +HttpParserTypes = NamedTuple( + 'HttpParserTypes', [ + ('REQUEST_PARSER', int), + ('RESPONSE_PARSER', int), + ], +) +httpParserTypes = HttpParserTypes(1, 2) diff --git a/proxy/http/plugin.py b/proxy/http/plugin.py new file mode 100644 index 000000000..754fb2802 --- /dev/null +++ b/proxy/http/plugin.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import socket +import argparse +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Union, Optional + +from .parser import HttpParser +from .connection import HttpClientConnection +from ..core.event import EventQueue +from .descriptors import DescriptorsHandlerMixin +from ..common.utils import tls_interception_enabled + + +if TYPE_CHECKING: # pragma: no cover + from ..core.connection import UpstreamConnectionPool + + +class HttpProtocolHandlerPlugin( + DescriptorsHandlerMixin, + ABC, +): + """Base HttpProtocolHandler Plugin class. + + NOTE: This is an internal plugin and in most cases only useful for core contributors. + If you are looking for proxy server plugins see ``. + + Implements various lifecycle events for an accepted client connection. + Following events are of interest: + + 1. 
Client Connection Accepted + A new plugin instance is created per accepted client connection. + Add your logic within __init__ constructor for any per connection setup. + 2. Client Request Chunk Received + on_client_data is called for every chunk of data sent by the client. + 3. Client Request Complete + on_request_complete is called once client request has completed. + 4. Server Response Chunk Received + on_response_chunk is called for every chunk received from the server. + 5. Client Connection Closed + Add your logic within `on_client_connection_close` for any per connection tear-down. + """ + + def __init__( + self, + uid: str, + flags: argparse.Namespace, + client: HttpClientConnection, + request: HttpParser, + event_queue: Optional[EventQueue] = None, + upstream_conn_pool: Optional['UpstreamConnectionPool'] = None, + ): + self.uid: str = uid + self.flags: argparse.Namespace = flags + self.client: HttpClientConnection = client + self.request: HttpParser = request + self.event_queue = event_queue + self.upstream_conn_pool = upstream_conn_pool + + @staticmethod + @abstractmethod + def protocols() -> List[int]: + raise NotImplementedError() + + @abstractmethod + def on_client_data(self, raw: memoryview) -> None: + """Called only after original request has been completely received.""" + pass # pragma: no cover + + @abstractmethod + def on_request_complete(self) -> Union[socket.socket, bool]: + """Called right after client request parser has reached COMPLETE state.""" + return False # pragma: no cover + + @abstractmethod + def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]: + """Handle data chunks as received from the server. + + Return optionally modified chunk to return back to client.""" + return chunk # pragma: no cover + + @abstractmethod + def on_client_connection_close(self) -> None: + """Client connection shutdown has been received, flush has been called, + perform any cleanup work here. + """ + pass # pragma: no cover + + @property + def tls_interception_enabled(self) -> bool: + return tls_interception_enabled(self.flags) diff --git a/proxy/http/protocols.py b/proxy/http/protocols.py new file mode 100644 index 000000000..49485720c --- /dev/null +++ b/proxy/http/protocols.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + .. 
spelling:: + + http + iterable +""" +from typing import NamedTuple + + +HttpProtocols = NamedTuple( + 'HttpProtocols', [ + ('UNKNOWN', int), + # Web server handling HTTP/1.0, HTTP/1.1, HTTP/2, HTTP/3 + # over plain Text or encrypted connection with clients + ('WEB_SERVER', int), + # Proxies handling HTTP/1.0, HTTP/1.1, HTTP/2 protocols + # over plain text connection or encrypted connection + # with clients + ('HTTP_PROXY', int), + # Proxies handling SOCKS4, SOCKS4a, SOCKS5 protocol + ('SOCKS_PROXY', int), + ], +) + +httpProtocols = HttpProtocols(1, 2, 3, 4) diff --git a/proxy/http/proxy/__init__.py b/proxy/http/proxy/__init__.py index afd352711..a794ba078 100644 --- a/proxy/http/proxy/__init__.py +++ b/proxy/http/proxy/__init__.py @@ -11,6 +11,7 @@ from .plugin import HttpProxyBasePlugin from .server import HttpProxyPlugin + __all__ = [ 'HttpProxyBasePlugin', 'HttpProxyPlugin', diff --git a/proxy/http/proxy/auth.py b/proxy/http/proxy/auth.py new file mode 100644 index 000000000..0238d212d --- /dev/null +++ b/proxy/http/proxy/auth.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + .. spelling:: + + auth + http +""" +from typing import Optional + +from ...http import httpHeaders +from ..exception import ProxyAuthenticationFailed +from ...http.proxy import HttpProxyBasePlugin +from ...http.parser import HttpParser + + +class AuthPlugin(HttpProxyBasePlugin): + """Performs proxy authentication.""" + + def before_upstream_connection( + self, request: HttpParser, + ) -> Optional[HttpParser]: + if self.flags.auth_code and request.headers: + if httpHeaders.PROXY_AUTHORIZATION not in request.headers: + raise ProxyAuthenticationFailed() + parts = request.headers[httpHeaders.PROXY_AUTHORIZATION][1].split() + if len(parts) != 2 \ + or parts[0].lower() != b'basic' \ + or parts[1] != self.flags.auth_code: + raise ProxyAuthenticationFailed() + return request diff --git a/proxy/http/proxy/plugin.py b/proxy/http/proxy/plugin.py index 44129ed86..7768e3c59 100644 --- a/proxy/http/proxy/plugin.py +++ b/proxy/http/proxy/plugin.py @@ -8,32 +8,43 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ -from abc import ABC, abstractmethod -from typing import Optional -from uuid import UUID +import argparse +from abc import ABC +from typing import TYPE_CHECKING, Any, Dict, Tuple, Optional + from ..parser import HttpParser +from ..connection import HttpClientConnection +from ...core.event import EventQueue +from ..descriptors import DescriptorsHandlerMixin +from ...common.utils import tls_interception_enabled -from ...common.flags import Flags -from ...core.event import EventQueue -from ...core.connection import TcpClientConnection +if TYPE_CHECKING: # pragma: no cover + from ...common.types import HostPort + from ...core.connection import UpstreamConnectionPool -class HttpProxyBasePlugin(ABC): +class HttpProxyBasePlugin( + DescriptorsHandlerMixin, + ABC, +): """Base HttpProxyPlugin Plugin class. 
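To make the `AuthPlugin` check above concrete: the plugin splits the `Proxy-Authorization` value and compares the second token against `flags.auth_code`. A sketch, assuming `auth_code` stores the base64-encoded `user:password` pair (as upstream's `--basic-auth` flag does); the credentials are hypothetical:

    import base64

    auth_code = base64.b64encode(b'user:topsecret')
    proxy_authorization = b'Basic ' + auth_code
    # AuthPlugin: parts = value.split(); parts[0] is compared
    # case-insensitively to b'basic', parts[1] to flags.auth_code.
    assert proxy_authorization.split()[1] == auth_code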
Implement various lifecycle event methods to customize behavior.""" def __init__( self, - uid: UUID, - flags: Flags, - client: TcpClientConnection, - event_queue: EventQueue) -> None: + uid: str, + flags: argparse.Namespace, + client: HttpClientConnection, + event_queue: EventQueue, + upstream_conn_pool: Optional['UpstreamConnectionPool'] = None, + ) -> None: self.uid = uid # pragma: no cover self.flags = flags # pragma: no cover self.client = client # pragma: no cover self.event_queue = event_queue # pragma: no cover + self.upstream_conn_pool = upstream_conn_pool def name(self) -> str: """A unique name for your plugin. @@ -42,9 +53,28 @@ def name(self) -> str: access a specific plugin by its name.""" return self.__class__.__name__ # pragma: no cover - @abstractmethod + def resolve_dns(self, host: str, port: int) -> Tuple[Optional[str], Optional['HostPort']]: + """Resolve upstream server host to an IP address. + + Optionally also override the source address to use for + connection with upstream server. + + For upstream IP: + Return None to use default resolver available to the system. + Return IP address as string to use your custom resolver. + + For source address: + Return None to use default source address + Return 2-tuple representing (host, port) to use as source address + """ + return None, None + + # No longer abstract since 2.4.0 + # + # @abstractmethod def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: """Handler called just before Proxy upstream connection is established. Return optionally modified request object. @@ -53,9 +83,31 @@ def before_upstream_connection( Raise HttpRequestRejected or HttpProtocolException directly to drop the connection.""" return request # pragma: no cover - @abstractmethod + # Since 3.4.0 + # + # @abstractmethod + def handle_client_data( + self, raw: memoryview, + ) -> Optional[memoryview]: + """Handler called in special scenarios when an upstream server connection + is never established. + + Essentially, if you return None from within before_upstream_connection, + be prepared to handle_client_data and not handle_client_request. + + Only called after initial request from client has been received. + + Raise HttpRequestRejected to tear down the connection + Return None to drop the connection + """ + return raw # pragma: no cover + + # No longer abstract since 2.4.0 + # + # @abstractmethod def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: + self, request: HttpParser, + ) -> Optional[HttpParser]: """Handler called before dispatching client request to upstream. Note: For pipelined (keep-alive) connections, this handler can be @@ -68,19 +120,52 @@ def handle_client_request( Return optionally modified request object to dispatch to upstream. Return None to drop the request data, e.g. in case a response has already been queued. Raise HttpRequestRejected or HttpProtocolException directly to - teardown the connection with client. + tear down the connection with client. """ return request # pragma: no cover - @abstractmethod - def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: + # No longer abstract since 2.4.0 + # + # @abstractmethod + def handle_upstream_chunk(self, chunk: memoryview) -> Optional[memoryview]: """Handler called right after receiving raw response from upstream server. For HTTPS connections, chunk will be encrypted unless - TLS interception is also enabled.""" + TLS interception is also enabled. 
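As a usage sketch of the hooks above, which are no longer abstract, a plugin now overrides only what it needs. The class below is hypothetical and simply rejects requests for one host:

    from typing import Optional

    from proxy.http.parser import HttpParser
    from proxy.http.proxy import HttpProxyBasePlugin


    class BlockHostPlugin(HttpProxyBasePlugin):
        """Drop requests destined for a single (made-up) host."""

        def handle_client_request(
            self, request: HttpParser,
        ) -> Optional[HttpParser]:
            if request.host == b'blocked.example.com':
                return None  # core drops the request
            return request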
+ + Return None if you don't want to sent this chunk to the client. + """ return chunk # pragma: no cover - @abstractmethod + # No longer abstract since 2.4.0 + # + # @abstractmethod def on_upstream_connection_close(self) -> None: """Handler called right after upstream connection has been closed.""" pass # pragma: no cover + + def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Use this method to override default access log format (see + DEFAULT_HTTP_ACCESS_LOG_FORMAT and DEFAULT_HTTPS_ACCESS_LOG_FORMAT) and to + add/update/modify/delete context for next plugin.on_access_log invocation. + + This is specially useful if a plugins want to provide extra context + in the access log which may not available within other plugins' context or even + in proxy.py core. + + Returns Log context or None. If plugin chooses to access log, they ideally + must return None to prevent other plugin.on_access_log invocation. + """ + return context + + def do_intercept(self, _request: HttpParser) -> bool: + """By default returns True (only) when necessary flags + for TLS interception are passed. + + When TLS interception is enabled, plugins can still disable + TLS interception by returning False explicitly. This hook + will allow you to run proxy instance with TLS interception + flags BUT only conditionally enable interception for + certain requests. + """ + return tls_interception_enabled(self.flags) diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 5ce2948db..facb31631 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -7,52 +7,113 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. spelling:: + + http + reusability """ -import logging -import threading import os import ssl -import socket import time import errno -from typing import Optional, List, Union, Dict, cast, Any, Tuple +import socket +import logging +import threading +import subprocess +from typing import Any, Dict, List, Union, Optional, cast from .plugin import HttpProxyBasePlugin -from ..handler import HttpProtocolHandlerPlugin -from ..exception import HttpProtocolException, ProxyConnectionFailed, ProxyAuthenticationFailed -from ..codes import httpStatusCodes -from ..parser import HttpParser, httpParserStates, httpParserTypes +from ..parser import HttpParser, httpParserTypes, httpParserStates +from ..plugin import HttpProtocolHandlerPlugin +from ..headers import httpHeaders from ..methods import httpMethods - -from ...common.types import HasFileno -from ...common.constants import PROXY_AGENT_HEADER_VALUE -from ...common.utils import build_http_response, text_ -from ...common.pki import gen_public_key, gen_csr, sign_csr - +from ..exception import HttpProtocolException, ProxyConnectionFailed +from ..protocols import httpProtocols +from ..responses import PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT +from ...common.pki import gen_csr, sign_csr, gen_public_key from ...core.event import eventNames -from ...core.connection import TcpServerConnection, TcpConnectionUninitializedException +from ...common.flag import flags +from ...common.types import Readables, Writables, Descriptors +from ...common.utils import text_ +from ...core.connection import ( + TcpServerConnection, TcpConnectionUninitializedException, +) +from ...common.constants import ( + COMMA, DEFAULT_CA_FILE, DEFAULT_CA_CERT_DIR, + DEFAULT_CA_KEY_FILE, DEFAULT_CA_CERT_FILE, DEFAULT_DISABLE_HEADERS, + PROXY_AGENT_HEADER_VALUE, + DEFAULT_CA_SIGNING_KEY_FILE, 
DEFAULT_HTTP_PROXY_ACCESS_LOG_FORMAT, + DEFAULT_HTTPS_PROXY_ACCESS_LOG_FORMAT, +) + logger = logging.getLogger(__name__) +flags.add_argument( + '--disable-headers', + type=str, + default=COMMA.join(DEFAULT_DISABLE_HEADERS), + help='Default: None. Comma separated list of headers to remove before ' + 'dispatching client request to upstream server.', +) + +flags.add_argument( + '--ca-key-file', + type=str, + default=DEFAULT_CA_KEY_FILE, + help='Default: None. CA key to use for signing dynamically generated ' + 'HTTPS certificates. If used, must also pass --ca-cert-file and --ca-signing-key-file', +) + +flags.add_argument( + '--ca-cert-dir', + type=str, + default=DEFAULT_CA_CERT_DIR, + help='Default: ~/.proxy/certificates. Directory to store dynamically generated certificates. ' + 'Also see --ca-key-file, --ca-cert-file and --ca-signing-key-file', +) + +flags.add_argument( + '--ca-cert-file', + type=str, + default=DEFAULT_CA_CERT_FILE, + help='Default: None. Signing certificate to use for signing dynamically generated ' + 'HTTPS certificates. If used, must also pass --ca-key-file and --ca-signing-key-file', +) + +flags.add_argument( + '--ca-file', + type=str, + default=str(DEFAULT_CA_FILE), + help='Default: ' + str(DEFAULT_CA_FILE) + + '. Provide path to custom CA bundle for peer certificate verification', +) + +flags.add_argument( + '--ca-signing-key-file', + type=str, + default=DEFAULT_CA_SIGNING_KEY_FILE, + help='Default: None. CA signing key to use for dynamic generation of ' + 'HTTPS certificates. If used, must also pass --ca-key-file and --ca-cert-file', +) + + class HttpProxyPlugin(HttpProtocolHandlerPlugin): """HttpProtocolHandler plugin which implements HttpProxy specifications.""" - PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = memoryview(build_http_response( - httpStatusCodes.OK, - reason=b'Connection established' - )) - - # Used to synchronize with other HttpProxyPlugin instances while - # generating certificates + # Used to synchronization during certificate generation and + # connection pool operations. 
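The flags registered above are picked up at startup through the new `flags` registry. For reference, a TLS-intercepting instance could be launched programmatically like this (a sketch: `proxy.main` accepts an argv-style list in upstream proxy.py, and all file paths below are placeholders):

    import proxy

    if __name__ == '__main__':
        proxy.main([
            '--ca-key-file', 'ca-key.pem',
            '--ca-cert-file', 'ca-cert.pem',
            '--ca-signing-key-file', 'ca-signing-key.pem',
            '--ca-cert-dir', '/tmp/certificates',
        ])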
lock = threading.Lock() def __init__( self, - *args: Any, **kwargs: Any) -> None: + *args: Any, **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) self.start_time: float = time.time() - self.server: Optional[TcpServerConnection] = None + self.upstream: Optional[TcpServerConnection] = None self.response: HttpParser = HttpParser(httpParserTypes.RESPONSE_PARSER) self.pipeline_request: Optional[HttpParser] = None self.pipeline_response: Optional[HttpParser] = None @@ -60,61 +121,104 @@ def __init__( self.plugins: Dict[str, HttpProxyBasePlugin] = {} if b'HttpProxyBasePlugin' in self.flags.plugins: for klass in self.flags.plugins[b'HttpProxyBasePlugin']: - instance = klass( + instance: HttpProxyBasePlugin = klass( self.uid, self.flags, self.client, - self.event_queue) + self.event_queue, + self.upstream_conn_pool, + ) self.plugins[instance.name()] = instance - def get_descriptors( - self) -> Tuple[List[socket.socket], List[socket.socket]]: - if not self.request.has_upstream_server(): - return [], [] - - r: List[socket.socket] = [] - w: List[socket.socket] = [] - if self.server and not self.server.closed and self.server.connection: - r.append(self.server.connection) - if self.server and not self.server.closed and \ - self.server.has_buffer() and self.server.connection: - w.append(self.server.connection) + @staticmethod + def protocols() -> List[int]: + return [httpProtocols.HTTP_PROXY] + + async def get_descriptors(self) -> Descriptors: + r: List[int] = [] + w: List[int] = [] + if ( + self.upstream and + not self.upstream.closed and + self.upstream.connection + ): + r.append(self.upstream.connection.fileno()) + if ( + self.upstream and + not self.upstream.closed and + self.upstream.has_buffer() and + self.upstream.connection + ): + w.append(self.upstream.connection.fileno()) + # TODO(abhinavsingh): We need to keep a mapping of plugin and + # descriptors registered by them, so that within write/read blocks + # we can invoke the right plugin callbacks. + for plugin in self.plugins.values(): + plugin_read_desc, plugin_write_desc = await plugin.get_descriptors() + r.extend(plugin_read_desc) + w.extend(plugin_write_desc) return r, w - def write_to_descriptors(self, w: List[Union[int, HasFileno]]) -> bool: - if self.request.has_upstream_server() and \ - self.server and not self.server.closed and \ - self.server.has_buffer() and \ - self.server.connection in w: - logger.debug('Server is write ready, flushing buffer') + async def write_to_descriptors(self, w: Writables) -> bool: + if (self.upstream and self.upstream.connection.fileno() not in w) or not self.upstream: + # Currently, we just call write/read block of each plugins. It is + # plugins responsibility to ignore this callback, if passed descriptors + # doesn't contain the descriptor they registered. 
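Note the signature change above: `get_descriptors` now returns plain integer file descriptors (`fileno()`) instead of socket objects, so the event loop can hand them straight to a selector. A self-contained sketch of that pattern:

    import socket
    import selectors

    sel = selectors.DefaultSelector()
    server = socket.socket()
    server.bind(('127.0.0.1', 0))
    server.listen()
    # Register the raw fileno, mirroring get_descriptors() above.
    sel.register(server.fileno(), selectors.EVENT_READ)
    for key, _events in sel.select(timeout=0):
        print('readable fd:', key.fd)
    sel.close()
    server.close()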
+ for plugin in self.plugins.values(): + teardown = await plugin.write_to_descriptors(w) + if teardown: + return True + elif self.upstream and not self.upstream.closed and \ + self.upstream.has_buffer() and \ + self.upstream.connection.fileno() in w: + logger.debug('Server is write ready, flushing...') try: - self.server.flush() + self.upstream.flush(self.flags.max_sendbuf_size) except ssl.SSLWantWriteError: - logger.warning('SSLWantWriteError while trying to flush to server, will retry') + logger.warning( + 'SSLWantWriteError while trying to flush to server, will retry', + ) return False except BrokenPipeError: - logger.error( - 'BrokenPipeError when flushing buffer for server') - return True + logger.warning( + 'BrokenPipeError when flushing buffer for server', + ) + return self._close_and_release() except OSError as e: - logger.exception('OSError when flushing buffer to server', exc_info=e) - return True + logger.exception( + 'OSError when flushing buffer to server', exc_info=e, + ) + return self._close_and_release() return False - def read_from_descriptors(self, r: List[Union[int, HasFileno]]) -> bool: - if self.request.has_upstream_server( - ) and self.server and not self.server.closed and self.server.connection in r: - logger.debug('Server is ready for reads, reading...') + async def read_from_descriptors(self, r: Readables) -> bool: + if ( + self.upstream and not + self.upstream.closed and + self.upstream.connection.fileno() not in r + ) or not self.upstream: + # Currently, we just call write/read block of each plugins. It is + # plugins responsibility to ignore this callback, if passed descriptors + # doesn't contain the descriptor they registered for. + for plugin in self.plugins.values(): + teardown = await plugin.read_from_descriptors(r) + if teardown: + return True + elif self.upstream \ + and not self.upstream.closed \ + and self.upstream.connection.fileno() in r: + logger.debug('Server is read ready, receiving...') try: - raw = self.server.recv(self.flags.server_recvbuf_size) + raw = self.upstream.recv(self.flags.server_recvbuf_size) except TimeoutError as e: + self._close_and_release() if e.errno == errno.ETIMEDOUT: logger.warning( '%s:%d timed out on recv' % - self.server.addr) + self.upstream.addr, + ) return True - else: - raise e + raise e except ssl.SSLWantReadError: # Try again later # logger.warning('SSLWantReadError encountered while reading from server, will retry ...') return False @@ -122,74 +226,147 @@ def read_from_descriptors(self, r: List[Union[int, HasFileno]]) -> bool: if e.errno == errno.EHOSTUNREACH: logger.warning( '%s:%d unreachable on recv' % - self.server.addr) - return True - elif e.errno == errno.ECONNRESET: - logger.warning('Connection reset by upstream: %r' % e) + self.upstream.addr, + ) + if e.errno == errno.ECONNRESET: + logger.warning( + 'Connection reset by upstream: {0}:{1}'.format( + *self.upstream.addr, + ), + ) else: - logger.exception( - 'Exception while receiving from %s connection %r with reason %r' % - (self.server.tag, self.server.connection, e)) - return True + logger.warning( + 'Exception while receiving from %s connection#%d with reason %r' % + (self.upstream.tag, self.upstream.connection.fileno(), e), + ) + return self._close_and_release() if raw is None: logger.debug('Server closed connection, tearing down...') - return True + return self._close_and_release() for plugin in self.plugins.values(): raw = plugin.handle_upstream_chunk(raw) + if raw is None: + break # parse incoming response packet # only for non-https requests and when 
# tls interception is enabled - if self.request.method != httpMethods.CONNECT: - # See https://github.com/abhinavsingh/proxy.py/issues/127 for why - # currently response parsing is disabled when TLS interception is enabled. - # - # or self.config.tls_interception_enabled(): - if self.response.state == httpParserStates.COMPLETE: - self.handle_pipeline_response(raw) + if raw is not None: + if ( + not self.request.is_https_tunnel + or self.tls_interception_enabled + ): + if self.response.is_complete: + self.handle_pipeline_response(raw) + else: + self.response.parse(raw) + self.emit_response_events(len(raw)) else: - # TODO(abhinavsingh): Remove .tobytes after parser is - # memoryview compliant - self.response.parse(raw.tobytes()) - self.emit_response_events() - else: - self.response.total_size += len(raw) - # queue raw data for client - self.client.queue(raw) + self.response.total_size += len(raw) + # queue raw data for client + self.client.queue(raw) return False def on_client_connection_close(self) -> None: - if not self.request.has_upstream_server(): - return - - self.access_log() - - # If server was never initialized, return - if self.server is None: - return + context = { + 'client_ip': None if not self.client.addr else self.client.addr[0], + 'client_port': None if not self.client.addr else self.client.addr[1], + 'server_host': text_(self.upstream.addr[0] if self.upstream else None), + 'server_port': text_(self.upstream.addr[1] if self.upstream else None), + 'connection_time_ms': '%.2f' % ((time.time() - self.start_time) * 1000), + # Request + 'request_method': text_(self.request.method), + 'request_path': text_(self.request.path), + 'request_bytes': text_(self.request.total_size), + 'request_ua': text_(self.request.header(b'user-agent')) + if self.request.has_header(b'user-agent') + else None, + 'request_version': text_(self.request.version), + # Response + 'response_bytes': self.response.total_size, + 'response_code': text_(self.response.code), + 'response_reason': text_(self.response.reason), + } + if self.flags.enable_proxy_protocol: + assert self.request.protocol and self.request.protocol.family + context.update({ + 'protocol': { + 'family': text_(self.request.protocol.family), + }, + }) + if self.request.protocol.source: + context.update({ + 'protocol': { + 'source_ip': text_(self.request.protocol.source[0]), + 'source_port': self.request.protocol.source[1], + }, + }) + if self.request.protocol.destination: + context.update({ + 'protocol': { + 'destination_ip': text_(self.request.protocol.destination[0]), + 'destination_port': self.request.protocol.destination[1], + }, + }) + + log_handled = False + for plugin in self.plugins.values(): + ctx = plugin.on_access_log(context) + if ctx is None: + log_handled = True + break + context = ctx + if not log_handled: + self.access_log(context) # Note that, server instance was initialized # but not necessarily the connection object exists. + # + # Unfortunately this is still being called when an upstream + # server connection was never established. This is done currently + # to assist proxy pool plugin to close its upstream proxy connections. + # + # In short, treat on_upstream_connection_close as on_client_connection_close + # equivalent within proxy plugins. 
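A sketch of the `on_access_log` contract used above: returning a (possibly modified) context forwards it to the next plugin, while returning None marks the line as handled and suppresses the default log. The path filtered below is hypothetical:

    from typing import Any, Dict, Optional

    from proxy.http.proxy import HttpProxyBasePlugin


    class QuietHealthChecksPlugin(HttpProxyBasePlugin):
        """Silence access logs for a health-check endpoint."""

        def on_access_log(
            self, context: Dict[str, Any],
        ) -> Optional[Dict[str, Any]]:
            if context.get('request_path') == '/healthz':
                return None  # handled; default log suppressed
            return context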
+ # # Invoke plugin.on_upstream_connection_close for plugin in self.plugins.values(): plugin.on_upstream_connection_close() + # If server was never initialized or was _close_and_release + if self.upstream is None: + return + + if self.flags.enable_conn_pool: + assert self.upstream_conn_pool + # Release the connection for reusability + with self.lock: + self.upstream_conn_pool.release(self.upstream) + return + try: try: - self.server.connection.shutdown(socket.SHUT_WR) + self.upstream.connection.shutdown(socket.SHUT_WR) except OSError: pass finally: # TODO: Unwrap if wrapped before close? - self.server.connection.close() + self.upstream.close() except TcpConnectionUninitializedException: pass finally: logger.debug( 'Closed server connection, has buffer %s' % - self.server.has_buffer()) + self.upstream.has_buffer(), + ) + + def access_log(self, log_attrs: Dict[str, Any]) -> None: + access_log_format = DEFAULT_HTTPS_PROXY_ACCESS_LOG_FORMAT + if not self.request.is_https_tunnel: + access_log_format = DEFAULT_HTTP_PROXY_ACCESS_LOG_FORMAT + logger.info(access_log_format.format_map(log_attrs)) def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]: # TODO: Allow to output multiple access_log lines @@ -197,65 +374,89 @@ def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]: # However, this must also be accompanied by resetting both request # and response objects. # - # if not self.request.method == httpMethods.CONNECT and \ - # self.response.state == httpParserStates.COMPLETE: + # if not self.request.is_https_tunnel and \ + # self.response.is_complete: # self.access_log() return chunk - def on_client_data(self, raw: memoryview) -> Optional[memoryview]: - if not self.request.has_upstream_server(): - return raw - - if self.server and not self.server.closed: - if self.request.state == httpParserStates.COMPLETE and ( - self.request.method != httpMethods.CONNECT or - self.flags.tls_interception_enabled()): + # Can return None to tear down connection + def on_client_data(self, raw: memoryview) -> None: + # For scenarios when an upstream connection was never established, + # let plugin do whatever they wish to. These are special scenarios + # where plugins are trying to do something magical. Within the core + # we don't know the context. In fact, we are not even sure if data + # exchanged is http spec compliant. + # + # Hence, here we pass raw data to HTTP proxy plugins as is. + # + # We only call handle_client_data once original request has been + # completely received + if not self.upstream: + for plugin in self.plugins.values(): + o = plugin.handle_client_data(raw) + if o is None: + return + raw = o + elif self.upstream and not self.upstream.closed: + # For http proxy requests, handle pipeline case. + # We also handle pipeline scenario for https proxy + # requests is TLS interception is enabled. + if self.request.is_complete and ( + not self.request.is_https_tunnel or + self.tls_interception_enabled + ): if self.pipeline_request is not None and \ - self.pipeline_request.is_connection_upgrade(): + self.pipeline_request.is_connection_upgrade: # Previous pipelined request was a WebSocket # upgrade request. Incoming client data now # must be treated as WebSocket protocol packets. 
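Because `access_log` above renders the line with `str.format_map`, any key present in the context dictionary may appear in the format string. A standalone sketch (the format string and values are illustrative, not the shipped defaults):

    fmt = (
        '{client_ip}:{client_port} - '
        '{request_method} {server_host}:{server_port} - '
        '{response_code} {response_reason} - {connection_time_ms}ms'
    )
    print(fmt.format_map({
        'client_ip': '127.0.0.1', 'client_port': 51000,
        'request_method': 'GET', 'server_host': 'example.com',
        'server_port': 80, 'response_code': '200',
        'response_reason': 'OK', 'connection_time_ms': '3.14',
    }))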
- self.server.queue(raw) - return None - + self.upstream.queue(raw) + return if self.pipeline_request is None: + # For pipeline requests, we never + # want to use --enable-proxy-protocol flag + # as proxy protocol header will not be present + # + # TODO: HTTP parser must be smart about detecting + # HA proxy protocol or we must always explicitly pass + # the flag when we are expecting HA proxy protocol + # request line before HTTP request lines. self.pipeline_request = HttpParser( - httpParserTypes.REQUEST_PARSER) - - # TODO(abhinavsingh): Remove .tobytes after parser is - # memoryview compliant - self.pipeline_request.parse(raw.tobytes()) - if self.pipeline_request.state == httpParserStates.COMPLETE: + httpParserTypes.REQUEST_PARSER, + ) + self.pipeline_request.parse(raw) + if self.pipeline_request.is_complete: for plugin in self.plugins.values(): assert self.pipeline_request is not None r = plugin.handle_client_request(self.pipeline_request) if r is None: - return None + return self.pipeline_request = r assert self.pipeline_request is not None # TODO(abhinavsingh): Remove memoryview wrapping here after # parser is fully memoryview compliant - self.server.queue( + self.upstream.queue( memoryview( - self.pipeline_request.build())) - if not self.pipeline_request.is_connection_upgrade(): + self.pipeline_request.build(), + ), + ) + if not self.pipeline_request.is_connection_upgrade: self.pipeline_request = None + # For scenarios where we cannot peek into the data, + # simply queue for upstream server. else: - self.server.queue(raw) - return None - else: - return raw + self.upstream.queue(raw) def on_request_complete(self) -> Union[socket.socket, bool]: - if not self.request.has_upstream_server(): - return False - self.emit_request_complete() - self.authenticate() - - # Note: can raise HttpRequestRejected exception # Invoke plugin.before_upstream_connection + # + # before_upstream_connection can: + # 1) Raise HttpRequestRejected exception to reject the connection + # 2) return None to continue without establishing an upstream server connection + # e.g. for scenarios when plugins want to return response from cache, or, + # via out-of-band over the network request. 
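A rough sketch of the "no upstream connection" flow described above: a plugin returning None from `before_upstream_connection` must be ready to answer the client itself from `handle_client_data`. `okResponse` comes from the new proxy/http/responses.py later in this patch; the wiring is illustrative, not prescriptive:

    from typing import Optional

    from proxy.http.parser import HttpParser
    from proxy.http.proxy import HttpProxyBasePlugin
    from proxy.http.responses import okResponse


    class AnswerLocallyPlugin(HttpProxyBasePlugin):
        """Never dial upstream; reply from within the plugin."""

        def before_upstream_connection(
            self, request: HttpParser,
        ) -> Optional[HttpParser]:
            return None  # core skips connect_upstream()

        def handle_client_data(
            self, raw: memoryview,
        ) -> Optional[memoryview]:
            self.client.queue(okResponse(content=b'served locally'))
            return raw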
do_connect = True for plugin in self.plugins.values(): r = plugin.before_upstream_connection(self.request) @@ -264,9 +465,11 @@ def on_request_complete(self) -> Union[socket.socket, bool]: break self.request = r + # Connect to upstream if do_connect: self.connect_upstream() + # Invoke plugin.handle_client_request for plugin in self.plugins.values(): assert self.request is not None r = plugin.handle_client_request(self.request) @@ -275,88 +478,139 @@ def on_request_complete(self) -> Union[socket.socket, bool]: else: return False - if self.request.method == httpMethods.CONNECT: - self.client.queue( - HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) - # If interception is enabled - if self.flags.tls_interception_enabled(): - # Perform SSL/TLS handshake with upstream - self.wrap_server() - # Generate certificate and perform handshake with client - try: - # wrap_client also flushes client data before wrapping - # sending to client can raise, handle expected exceptions - self.wrap_client() - except BrokenPipeError: - logger.error( - 'BrokenPipeError when wrapping client') - return True - except OSError as e: - logger.exception( - 'OSError when wrapping client', exc_info=e) - return True - # Update all plugin connection reference - for plugin in self.plugins.values(): - plugin.client._conn = self.client.connection - return self.client.connection - elif self.server: - # - proxy-connection header is a mistake, it doesn't seem to be - # officially documented in any specification, drop it. - # - proxy-authorization is of no use for upstream, remove it. - self.request.del_headers( - [b'proxy-authorization', b'proxy-connection']) - # - For HTTP/1.0, connection header defaults to close - # - For HTTP/1.1, connection header defaults to keep-alive - # Respect headers sent by client instead of manipulating - # Connection or Keep-Alive header. However, note that per - # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Connection - # connection headers are meant for communication between client and - # first intercepting proxy. - self.request.add_headers( - [(b'Via', b'1.1 %s' % PROXY_AGENT_HEADER_VALUE)]) - # Disable args.disable_headers before dispatching to upstream - self.server.queue( - memoryview(self.request.build( - disable_headers=self.flags.disable_headers))) + # For https requests, respond back with tunnel established response. + # Optionally, setup interceptor if TLS interception is enabled. + if self.upstream: + if self.request.is_https_tunnel: + self.client.queue(PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) + if self.tls_interception_enabled: + # Check if any plugin wants to + # disable interception even + # with flags available + do_intercept = True + for plugin in self.plugins.values(): + do_intercept = plugin.do_intercept(self.request) + # A plugin requested to not intercept + # the request + if do_intercept is False: + break + if do_intercept: + return self.intercept() + # If an upstream server connection was established for http request, + # queue the request for upstream server. + else: + # - proxy-connection header is a mistake, it doesn't seem to be + # officially documented in any specification, drop it. + # - proxy-authorization is of no use for upstream, remove it. 
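The `do_intercept` loop above makes interception opt-out per request: even with all TLS interception flags passed, a plugin can tunnel selected hosts untouched. A hypothetical sketch:

    from proxy.http.parser import HttpParser
    from proxy.http.proxy import HttpProxyBasePlugin


    class SelectiveInterceptPlugin(HttpProxyBasePlugin):
        """Never intercept a pinned (made-up) banking domain."""

        def do_intercept(self, _request: HttpParser) -> bool:
            if _request.host == b'bank.example.com':
                return False  # plain CONNECT tunnel, no interception
            return super().do_intercept(_request)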
+ self.request.del_headers( + [ + httpHeaders.PROXY_AUTHORIZATION, + httpHeaders.PROXY_CONNECTION, + ], + ) + # - For HTTP/1.0, connection header defaults to close + # - For HTTP/1.1, connection header defaults to keep-alive + # Respect headers sent by client instead of manipulating + # Connection or Keep-Alive header. However, note that per + # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Connection + # connection headers are meant for communication between client and + # first intercepting proxy. + self.request.add_headers( + [(b'Via', b'1.1 %s' % PROXY_AGENT_HEADER_VALUE)], + ) + # Disable args.disable_headers before dispatching to upstream + self.upstream.queue( + memoryview( + self.request.build( + disable_headers=self.flags.disable_headers, + ), + ), + ) return False def handle_pipeline_response(self, raw: memoryview) -> None: if self.pipeline_response is None: self.pipeline_response = HttpParser( - httpParserTypes.RESPONSE_PARSER) - # TODO(abhinavsingh): Remove .tobytes after parser is memoryview - # compliant - self.pipeline_response.parse(raw.tobytes()) - if self.pipeline_response.state == httpParserStates.COMPLETE: + httpParserTypes.RESPONSE_PARSER, + ) + self.pipeline_response.parse(raw) + if self.pipeline_response.is_complete: self.pipeline_response = None - def access_log(self) -> None: - server_host, server_port = self.server.addr if self.server else ( - None, None) - connection_time_ms = (time.time() - self.start_time) * 1000 - if self.request.method == b'CONNECT': - logger.info( - '%s:%s - %s %s:%s - %s bytes - %.2f ms' % - (self.client.addr[0], - self.client.addr[1], - text_(self.request.method), - text_(server_host), - text_(server_port), - self.response.total_size, - connection_time_ms)) - elif self.request.method: - logger.info( - '%s:%s - %s %s:%s%s - %s %s - %s bytes - %.2f ms' % - (self.client.addr[0], self.client.addr[1], - text_(self.request.method), - text_(server_host), server_port, - text_(self.request.path), - text_(self.response.code), - text_(self.response.reason), - self.response.total_size, - connection_time_ms)) - - def gen_ca_signed_certificate(self, cert_file_path: str, certificate: Dict[str, Any]) -> None: + def connect_upstream(self) -> None: + host, port = self.request.host, self.request.port + if host and port: + try: + # Invoke plugin.resolve_dns + upstream_ip, source_addr = None, None + for plugin in self.plugins.values(): + upstream_ip, source_addr = plugin.resolve_dns( + text_(host), port, + ) + if upstream_ip or source_addr: + break + logger.debug( + 'Connecting to upstream %s:%d' % + (text_(host), port), + ) + if self.flags.enable_conn_pool: + assert self.upstream_conn_pool + with self.lock: + created, self.upstream = self.upstream_conn_pool.acquire( + (text_(host), port), + ) + else: + created, self.upstream = True, TcpServerConnection( + text_(host), port, + ) + # Connect with overridden upstream IP and source address + # if any of the plugin returned a non-null value. + self.upstream.connect( + addr=None if not upstream_ip else ( + upstream_ip, port, + ), source_address=source_addr, + ) + self.upstream.connection.setblocking(False) + if not created: + # NOTE: Acquired connection might be in an unusable state. + # + # This can only be confirmed by reading from connection. + # For stale connections, we will receive None, indicating + # to drop the connection. + # + # If that happen, we must acquire a fresh connection. 
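`connect_upstream` above consults `plugin.resolve_dns` before dialing out. A sketch of a custom resolver pinning one hostname to a fixed IP and outbound source address (all names and addresses are made up, and the `HostPort` alias is assumed to be a host/port tuple):

    from typing import Optional, Tuple

    from proxy.http.proxy import HttpProxyBasePlugin


    class PinnedResolverPlugin(HttpProxyBasePlugin):
        """Resolve one hostname statically; defer everything else."""

        def resolve_dns(
            self, host: str, port: int,
        ) -> Tuple[Optional[str], Optional[Tuple[str, int]]]:
            if host == 'internal.example':
                # (upstream_ip, (source_host, source_port))
                return '10.0.0.5', ('192.168.1.10', 0)
            return None, None  # fall back to the system resolver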
+ logger.info( + 'Reusing connection to upstream %s:%d' % + (text_(host), port), + ) + return + logger.debug( + 'Connected to upstream %s:%s' % + (text_(host), port), + ) + except Exception as e: # TimeoutError, socket.gaierror + logger.warning( + 'Unable to connect with upstream %s:%d due to %s' % ( + text_(host), port, str(e), + ), + ) + if self.flags.enable_conn_pool and self.upstream: + assert self.upstream_conn_pool + with self.lock: + self.upstream_conn_pool.release(self.upstream) + raise ProxyConnectionFailed( + text_(host), port, repr(e), + ) from e + else: + raise HttpProtocolException('Both host and port must exist') + + # + # Interceptor related methods + # + + def gen_ca_signed_certificate( + self, cert_file_path: str, certificate: Dict[str, Any], + ) -> None: '''CA signing key (default) is used for generating a public key for common_name, if one already doesn't exist. Using generated public key a CSR request is generated, which is then signed by @@ -364,55 +618,80 @@ def gen_ca_signed_certificate(self, cert_file_path: str, certificate: Dict[str, certificate doesn't already exist. returns signed certificate path.''' - assert(self.request.host and self.flags.ca_cert_dir and self.flags.ca_signing_key_file and - self.flags.ca_key_file and self.flags.ca_cert_file) + assert( + self.request.host and self.flags.ca_cert_dir and self.flags.ca_signing_key_file and + self.flags.ca_key_file and self.flags.ca_cert_file + ) upstream_subject = {s[0][0]: s[0][1] for s in certificate['subject']} - public_key_path = os.path.join(self.flags.ca_cert_dir, - '{0}.{1}'.format(text_(self.request.host), 'pub')) + public_key_path = os.path.join( + self.flags.ca_cert_dir, + '{0}.{1}'.format(text_(self.request.host), 'pub'), + ) private_key_path = self.flags.ca_signing_key_file private_key_password = '' - subject = '/CN={0}/C={1}/ST={2}/L={3}/O={4}/OU={5}'.format( - upstream_subject.get('commonName', text_(self.request.host)), - upstream_subject.get('countryName', 'NA'), - upstream_subject.get('stateOrProvinceName', 'Unavailable'), - upstream_subject.get('localityName', 'Unavailable'), - upstream_subject.get('organizationName', 'Unavailable'), - upstream_subject.get('organizationalUnitName', 'Unavailable')) - alt_subj_names = [text_(self.request.host), ] + + # Build certificate subject + keys = { + 'CN': 'commonName', + 'C': 'countryName', + 'ST': 'stateOrProvinceName', + 'L': 'localityName', + 'O': 'organizationName', + 'OU': 'organizationalUnitName', + } + subject = '' + for key in keys: + if upstream_subject.get(keys[key], None): + subject += '/{0}={1}'.format( + key, + upstream_subject.get(keys[key]), + ) + alt_subj_names = [text_(self.request.host)] validity_in_days = 365 * 2 timeout = 10 # Generate a public key for the common name if not os.path.isfile(public_key_path): logger.debug('Generating public key %s', public_key_path) - resp = gen_public_key(public_key_path=public_key_path, private_key_path=private_key_path, - private_key_password=private_key_password, subject=subject, alt_subj_names=alt_subj_names, - validity_in_days=validity_in_days, timeout=timeout) + resp = gen_public_key( + public_key_path=public_key_path, private_key_path=private_key_path, + private_key_password=private_key_password, subject=subject, alt_subj_names=alt_subj_names, + validity_in_days=validity_in_days, timeout=timeout, + openssl=self.flags.openssl, + ) assert(resp is True) - csr_path = os.path.join(self.flags.ca_cert_dir, - '{0}.{1}'.format(text_(self.request.host), 'csr')) + csr_path = os.path.join( + 
self.flags.ca_cert_dir, + '{0}.{1}'.format(text_(self.request.host), 'csr'), + ) # Generate a CSR request for this common name if not os.path.isfile(csr_path): logger.debug('Generating CSR %s', csr_path) - resp = gen_csr(csr_path=csr_path, key_path=private_key_path, password=private_key_password, - crt_path=public_key_path, timeout=timeout) + resp = gen_csr( + csr_path=csr_path, key_path=private_key_path, password=private_key_password, + crt_path=public_key_path, timeout=timeout, + openssl=self.flags.openssl, + ) assert(resp is True) ca_key_path = self.flags.ca_key_file ca_key_password = '' ca_crt_path = self.flags.ca_cert_file - serial = self.uid.int + serial = '%d%d' % (time.time(), os.getpid()) # Sign generated CSR if not os.path.isfile(cert_file_path): logger.debug('Signing CSR %s', cert_file_path) - resp = sign_csr(csr_path=csr_path, crt_path=cert_file_path, ca_key_path=ca_key_path, - ca_key_password=ca_key_password, ca_crt_path=ca_crt_path, - serial=str(serial), alt_subj_names=alt_subj_names, - validity_in_days=validity_in_days, timeout=timeout) + resp = sign_csr( + csr_path=csr_path, crt_path=cert_file_path, ca_key_path=ca_key_path, + ca_key_password=ca_key_password, ca_crt_path=ca_crt_path, + serial=str(serial), alt_subj_names=alt_subj_names, + validity_in_days=validity_in_days, timeout=timeout, + openssl=self.flags.openssl, + ) assert(resp is True) @staticmethod @@ -420,118 +699,242 @@ def generated_cert_file_path(ca_cert_dir: str, host: str) -> str: return os.path.join(ca_cert_dir, '%s.pem' % host) def generate_upstream_certificate( - self, certificate: Dict[str, Any]) -> str: - if not (self.flags.ca_cert_dir and self.flags.ca_signing_key_file and - self.flags.ca_cert_file and self.flags.ca_key_file): + self, certificate: Dict[str, Any], + ) -> str: + if not ( + self.flags.ca_cert_dir and self.flags.ca_signing_key_file and + self.flags.ca_cert_file and self.flags.ca_key_file + ): raise HttpProtocolException( f'For certificate generation all the following flags are mandatory: ' f'--ca-cert-file:{ self.flags.ca_cert_file }, ' f'--ca-key-file:{ self.flags.ca_key_file }, ' - f'--ca-signing-key-file:{ self.flags.ca_signing_key_file }') + f'--ca-signing-key-file:{ self.flags.ca_signing_key_file }', + ) cert_file_path = HttpProxyPlugin.generated_cert_file_path( - self.flags.ca_cert_dir, text_(self.request.host)) + self.flags.ca_cert_dir, text_(self.request.host), + ) with self.lock: if not os.path.isfile(cert_file_path): self.gen_ca_signed_certificate(cert_file_path, certificate) return cert_file_path - def wrap_server(self) -> None: - assert self.server is not None - assert isinstance(self.server.connection, socket.socket) - ctx = ssl.create_default_context( - ssl.Purpose.SERVER_AUTH, cafile=self.flags.ca_file) - ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 - ctx.check_hostname = True - self.server.connection.setblocking(True) - self.server._conn = ctx.wrap_socket( - self.server.connection, - server_hostname=text_(self.request.host)) - self.server.connection.setblocking(False) - - def wrap_client(self) -> None: - assert self.server is not None - assert isinstance(self.server.connection, ssl.SSLSocket) - generated_cert = self.generate_upstream_certificate( - cast(Dict[str, Any], self.server.connection.getpeercert())) - self.client.connection.setblocking(True) - self.client.flush() - self.client._conn = ssl.wrap_socket( - self.client.connection, - server_side=True, - certfile=generated_cert, - keyfile=self.flags.ca_signing_key_file, - 
ssl_version=ssl.PROTOCOL_TLSv1_2) - self.client.connection.setblocking(False) - logger.debug( - 'TLS interception using %s', generated_cert) - - def authenticate(self) -> None: - if self.flags.auth_code: - if b'proxy-authorization' not in self.request.headers or \ - self.request.headers[b'proxy-authorization'][1] != self.flags.auth_code: - raise ProxyAuthenticationFailed() + def intercept(self) -> Union[socket.socket, bool]: + # Perform SSL/TLS handshake with upstream + teardown = self.wrap_server() + if teardown: + return teardown + # Generate certificate and perform handshake with client + # wrap_client also flushes client data before wrapping + # sending to client can raise, handle expected exceptions + teardown = self.wrap_client() + if teardown: + return teardown + # Update all plugin connection reference + # TODO(abhinavsingh): Is this required? + for plugin in self.plugins.values(): + plugin.client._conn = self.client.connection + return self.client.connection - def connect_upstream(self) -> None: - host, port = self.request.host, self.request.port - if host and port: - self.server = TcpServerConnection(text_(host), port) - try: - logger.debug( - 'Connecting to upstream %s:%s' % - (text_(host), port)) - self.server.connect() - self.server.connection.setblocking(False) - logger.debug( - 'Connected to upstream %s:%s' % - (text_(host), port)) - except Exception as e: # TimeoutError, socket.gaierror - self.server.closed = True - raise ProxyConnectionFailed(text_(host), port, repr(e)) from e - else: - logger.exception('Both host and port must exist') - raise HttpProtocolException() + def wrap_server(self) -> bool: + assert self.upstream is not None + assert isinstance(self.upstream.connection, socket.socket) + do_close = False + try: + self.upstream.wrap( + text_(self.request.host), + self.flags.ca_file, + as_non_blocking=True, + ) + except ssl.SSLCertVerificationError: # Server raised certificate verification error + # When --disable-interception-on-ssl-cert-verification-error flag is on, + # we will cache such upstream hosts and avoid intercepting them for future + # requests. 
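For context on the `upstream.wrap(...)` call above, which replaces the inline context setup deleted in this hunk, the verification behaviour is roughly the stdlib pattern below; the host is a placeholder and failures surface as the `ssl.SSLCertVerificationError` handled in this method:

    import ssl
    import socket

    ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
    with socket.create_connection(('example.com', 443)) as sock:
        with ctx.wrap_socket(sock, server_hostname='example.com') as tls:
            # Raises ssl.SSLCertVerificationError for a bad certificate.
            print(tls.version())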
+ logger.warning( + 'ssl.SSLCertVerificationError: ' + + 'Server raised cert verification error for upstream: {0}'.format( + self.upstream.addr[0], + ), + ) + do_close = True + except ssl.SSLError as e: + if e.reason == 'SSLV3_ALERT_HANDSHAKE_FAILURE': + logger.warning( + '{0}: '.format(e.reason) + + 'Server raised handshake alert failure for upstream: {0}'.format( + self.upstream.addr[0], + ), + ) + else: + logger.exception( + 'SSLError when wrapping client for upstream: {0}'.format( + self.upstream.addr[0], + ), exc_info=e, + ) + do_close = True + if not do_close: + assert isinstance(self.upstream.connection, ssl.SSLSocket) + return do_close + + def wrap_client(self) -> bool: + assert self.upstream is not None and self.flags.ca_signing_key_file is not None + assert isinstance(self.upstream.connection, ssl.SSLSocket) + do_close = False + try: + # TODO: Perform async certificate generation + generated_cert = self.generate_upstream_certificate( + cast(Dict[str, Any], self.upstream.connection.getpeercert()), + ) + self.client.wrap(self.flags.ca_signing_key_file, generated_cert) + except subprocess.TimeoutExpired as e: # Popen communicate timeout + logger.exception( + 'TimeoutExpired during certificate generation', exc_info=e, + ) + do_close = True + except ssl.SSLCertVerificationError: # Client raised certificate verification error + # When --disable-interception-on-ssl-cert-verification-error flag is on, + # we will cache such upstream hosts and avoid intercepting them for future + # requests. + logger.warning( + 'ssl.SSLCertVerificationError: ' + + 'Client raised cert verification error for upstream: {0}'.format( + self.upstream.addr[0], + ), + ) + do_close = True + except ssl.SSLEOFError as e: + logger.warning( + 'ssl.SSLEOFError {0} when wrapping client for upstream: {1}'.format( + str(e), self.upstream.addr[0], + ), + ) + do_close = True + except ssl.SSLError as e: + if e.reason in ('TLSV1_ALERT_UNKNOWN_CA', 'UNSUPPORTED_PROTOCOL'): + logger.warning( + '{0}: '.format(e.reason) + + 'Client raised cert verification error for upstream: {0}'.format( + self.upstream.addr[0], + ), + ) + else: + logger.exception( + 'OSError when wrapping client for upstream: {0}'.format( + self.upstream.addr[0], + ), exc_info=e, + ) + do_close = True + except BrokenPipeError: + logger.warning( + 'BrokenPipeError when wrapping client for upstream: {0}'.format( + self.upstream.addr[0], + ), + ) + do_close = True + except OSError as e: + logger.exception( + 'OSError when wrapping client for upstream: {0}'.format( + self.upstream.addr[0], + ), exc_info=e, + ) + do_close = True + if not do_close: + logger.debug('TLS intercepting using %s', generated_cert) + return do_close + + # + # Event emitter callbacks + # def emit_request_complete(self) -> None: if not self.flags.enable_events: return - - assert self.request.path - assert self.request.port + assert self.request.port and self.event_queue self.event_queue.publish( - request_id=self.uid.hex, + request_id=self.uid, event_name=eventNames.REQUEST_COMPLETE, event_payload={ 'url': text_(self.request.path) - if self.request.method == httpMethods.CONNECT + if self.request.is_https_tunnel else 'http://%s:%d%s' % (text_(self.request.host), self.request.port, text_(self.request.path)), 'method': text_(self.request.method), - 'headers': {text_(k): text_(v[1]) for k, v in self.request.headers.items()}, - 'body': text_(self.request.body) + 'headers': {} + if not self.request.headers else + { + text_(k): text_(v[1]) + for k, v in self.request.headers.items() + }, + 'body': 
text_(self.request.body, errors='ignore') if self.request.method == httpMethods.POST - else None + else None, }, - publisher_id=self.__class__.__name__ + publisher_id=self.__class__.__name__, ) - def emit_response_events(self) -> None: + def emit_response_events(self, chunk_size: int) -> None: if not self.flags.enable_events: return - - if self.response.state == httpParserStates.COMPLETE: + if self.response.is_complete: self.emit_response_complete() elif self.response.state == httpParserStates.RCVING_BODY: - self.emit_response_chunk_received() + self.emit_response_chunk_received(chunk_size) elif self.response.state == httpParserStates.HEADERS_COMPLETE: self.emit_response_headers_complete() def emit_response_headers_complete(self) -> None: if not self.flags.enable_events: return + assert self.event_queue + self.event_queue.publish( + request_id=self.uid, + event_name=eventNames.RESPONSE_HEADERS_COMPLETE, + event_payload={ + 'headers': {} + if not self.response.headers else + { + text_(k): text_(v[1]) + for k, v in self.response.headers.items() + }, + }, + publisher_id=self.__class__.__name__, + ) - def emit_response_chunk_received(self) -> None: + def emit_response_chunk_received(self, chunk_size: int) -> None: if not self.flags.enable_events: return + assert self.event_queue + self.event_queue.publish( + request_id=self.uid, + event_name=eventNames.RESPONSE_CHUNK_RECEIVED, + event_payload={ + 'chunk_size': chunk_size, + 'encoded_chunk_size': chunk_size, + }, + publisher_id=self.__class__.__name__, + ) def emit_response_complete(self) -> None: if not self.flags.enable_events: return + assert self.event_queue + self.event_queue.publish( + request_id=self.uid, + event_name=eventNames.RESPONSE_COMPLETE, + event_payload={ + 'encoded_response_size': self.response.total_size, + }, + publisher_id=self.__class__.__name__, + ) + + # + # Internal methods + # + + def _close_and_release(self) -> bool: + if self.flags.enable_conn_pool: + assert self.upstream and not self.upstream.closed and self.upstream_conn_pool + self.upstream.closed = True + with self.lock: + self.upstream_conn_pool.release(self.upstream) + self.upstream = None + return True diff --git a/proxy/http/responses.py b/proxy/http/responses.py new file mode 100644 index 000000000..d1e722fca --- /dev/null +++ b/proxy/http/responses.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. 
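+ + Illustrative usage (a sketch of how the helpers below are meant to be called, not part of this module's code): + + from proxy.http.responses import okResponse + pkt = okResponse(content=b'Hello') # small body, served as-is + big = okResponse(content=b'x' * 4096) # gzip-compressed once content exceeds DEFAULT_MIN_COMPRESSION_LENGTH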
+""" +import gzip +from typing import Any, Dict, Optional + +from .codes import httpStatusCodes +from ..common.utils import build_http_response +from ..common.constants import ( + PROXY_AGENT_HEADER_KEY, PROXY_AGENT_HEADER_VALUE, + DEFAULT_MIN_COMPRESSION_LENGTH, +) + + +PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = memoryview( + build_http_response( + httpStatusCodes.OK, + reason=b'Connection established', + no_cl=True, + ), +) + +PROXY_TUNNEL_UNSUPPORTED_SCHEME = memoryview( + build_http_response( + httpStatusCodes.BAD_REQUEST, + reason=b'Unsupported protocol scheme', + conn_close=True, + no_cl=True, + ), +) + +PROXY_AUTH_FAILED_RESPONSE_PKT = memoryview( + build_http_response( + httpStatusCodes.PROXY_AUTH_REQUIRED, + reason=b'Proxy Authentication Required', + headers={ + PROXY_AGENT_HEADER_KEY: PROXY_AGENT_HEADER_VALUE, + b'Proxy-Authenticate': b'Basic', + }, + body=b'Proxy Authentication Required', + conn_close=True, + no_cl=True, + ), +) + +BAD_REQUEST_RESPONSE_PKT = memoryview( + build_http_response( + httpStatusCodes.BAD_REQUEST, + reason=b'BAD REQUEST', + headers={ + b'Server': PROXY_AGENT_HEADER_VALUE, + }, + conn_close=True, + ), +) + +NOT_FOUND_RESPONSE_PKT = memoryview( + build_http_response( + httpStatusCodes.NOT_FOUND, + reason=b'NOT FOUND', + headers={ + b'Server': PROXY_AGENT_HEADER_VALUE, + }, + conn_close=True, + ), +) + +NOT_IMPLEMENTED_RESPONSE_PKT = memoryview( + build_http_response( + httpStatusCodes.NOT_IMPLEMENTED, + reason=b'NOT IMPLEMENTED', + headers={ + b'Server': PROXY_AGENT_HEADER_VALUE, + }, + conn_close=True, + ), +) + +BAD_GATEWAY_RESPONSE_PKT = memoryview( + build_http_response( + httpStatusCodes.BAD_GATEWAY, + reason=b'Bad Gateway', + headers={ + PROXY_AGENT_HEADER_KEY: PROXY_AGENT_HEADER_VALUE, + }, + body=b'Bad Gateway', + conn_close=True, + no_cl=True, + ), +) + + +def okResponse( + content: Optional[bytes] = None, + headers: Optional[Dict[bytes, bytes]] = None, + compress: bool = True, + min_compression_length: int = DEFAULT_MIN_COMPRESSION_LENGTH, + **kwargs: Any, +) -> memoryview: + do_compress: bool = False + if compress and content and len(content) > min_compression_length: + do_compress = True + if not headers: + headers = {} + headers.update({ + b'Content-Encoding': b'gzip', + }) + return memoryview( + build_http_response( + 200, + reason=b'OK', + headers=headers, + body=gzip.compress(content) + if do_compress and content + else content, + **kwargs, + ), + ) + + +def permanentRedirectResponse(location: bytes) -> memoryview: + return memoryview( + build_http_response( + httpStatusCodes.PERMANENT_REDIRECT, + reason=b'Permanent Redirect', + headers={ + b'Location': location, + b'Content-Length': b'0', + }, + conn_close=True, + ), + ) + + +def seeOthersResponse(location: bytes) -> memoryview: + return memoryview( + build_http_response( + httpStatusCodes.SEE_OTHER, + reason=b'See Other', + headers={ + b'Location': location, + b'Content-Length': b'0', + }, + conn_close=True, + ), + ) diff --git a/proxy/http/server/__init__.py b/proxy/http/server/__init__.py index 059c2cc12..dfbaa02c8 100644 --- a/proxy/http/server/__init__.py +++ b/proxy/http/server/__init__.py @@ -9,9 +9,10 @@ :license: BSD, see LICENSE for more details. 
""" from .web import HttpWebServerPlugin -from .pac_plugin import HttpWebServerPacFilePlugin from .plugin import HttpWebServerBasePlugin from .protocols import httpProtocolTypes +from .pac_plugin import HttpWebServerPacFilePlugin + __all__ = [ 'HttpWebServerPlugin', diff --git a/proxy/plugin/cache/store/__init__.py b/proxy/http/server/middleware.py similarity index 71% rename from proxy/plugin/cache/store/__init__.py rename to proxy/http/server/middleware.py index 232621f0b..ba52947fd 100644 --- a/proxy/plugin/cache/store/__init__.py +++ b/proxy/http/server/middleware.py @@ -8,3 +8,9 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ + + +class HttpWebServerBaseMiddleware: + """Web Server Middle-ware for customization during request/response dispatch lifecycle.""" + + pass diff --git a/proxy/http/server/pac_plugin.py b/proxy/http/server/pac_plugin.py index 0dfa0b490..a52f213dd 100644 --- a/proxy/http/server/pac_plugin.py +++ b/proxy/http/server/pac_plugin.py @@ -7,15 +7,37 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. spelling:: + + pac """ -import gzip -from typing import List, Tuple, Optional, Any +from typing import Any, List, Tuple, Optional from .plugin import HttpWebServerBasePlugin -from .protocols import httpProtocolTypes -from ..websocket import WebsocketFrame from ..parser import HttpParser -from ...common.utils import bytes_, text_, build_http_response +from .protocols import httpProtocolTypes +from ..responses import okResponse +from ...common.flag import flags +from ...common.utils import text_, bytes_ +from ...common.constants import DEFAULT_PAC_FILE, DEFAULT_PAC_FILE_URL_PATH + + +flags.add_argument( + '--pac-file', + type=str, + default=DEFAULT_PAC_FILE, + help='A file (Proxy Auto Configuration) or string to serve when ' + 'the server receives a direct file request. ' + 'Using this option enables proxy.HttpWebServerPlugin.', +) +flags.add_argument( + '--pac-file-url-path', + type=str, + default=text_(DEFAULT_PAC_FILE_URL_PATH), + help='Default: %s. Web server path to serve the PAC file.' 
% + text_(DEFAULT_PAC_FILE_URL_PATH), +) class HttpWebServerPacFilePlugin(HttpWebServerBasePlugin): @@ -28,8 +50,16 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: def routes(self) -> List[Tuple[int, str]]: if self.flags.pac_file_url_path: return [ - (httpProtocolTypes.HTTP, text_(self.flags.pac_file_url_path)), - (httpProtocolTypes.HTTPS, text_(self.flags.pac_file_url_path)), + ( + httpProtocolTypes.HTTP, r'{0}$'.format( + text_(self.flags.pac_file_url_path), + ), + ), + ( + httpProtocolTypes.HTTPS, r'{0}$'.format( + text_(self.flags.pac_file_url_path), + ), + ), ] return [] # pragma: no cover @@ -37,15 +67,6 @@ def handle_request(self, request: HttpParser) -> None: if self.flags.pac_file and self.pac_file_response: self.client.queue(self.pac_file_response) - def on_websocket_open(self) -> None: - pass # pragma: no cover - - def on_websocket_message(self, frame: WebsocketFrame) -> None: - pass # pragma: no cover - - def on_websocket_close(self) -> None: - pass # pragma: no cover - def cache_pac_file_response(self) -> None: if self.flags.pac_file: try: @@ -53,9 +74,11 @@ def cache_pac_file_response(self) -> None: content = f.read() except IOError: content = bytes_(self.flags.pac_file) - self.pac_file_response = memoryview(build_http_response( - 200, reason=b'OK', headers={ + self.pac_file_response = okResponse( + content=content, + headers={ b'Content-Type': b'application/x-ns-proxy-autoconfig', - b'Content-Encoding': b'gzip', - }, body=gzip.compress(content) - )) + }, + conn_close=True, + compress=False, + ) diff --git a/proxy/http/server/plugin.py b/proxy/http/server/plugin.py index a1e17c1e6..544d39ab8 100644 --- a/proxy/http/server/plugin.py +++ b/proxy/http/server/plugin.py @@ -8,30 +8,71 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. 
""" +import argparse +import mimetypes from abc import ABC, abstractmethod -from typing import List, Tuple -from uuid import UUID -from ..websocket import WebsocketFrame -from ..parser import HttpParser +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union, Optional -from ...common.flags import Flags -from ...core.connection import TcpClientConnection +from ..parser import HttpParser +from ...http.url import Url +from ..responses import NOT_FOUND_RESPONSE_PKT, okResponse +from ..websocket import WebsocketFrame +from ..connection import HttpClientConnection from ...core.event import EventQueue +from ..descriptors import DescriptorsHandlerMixin +from ...common.types import RePattern +from ...common.utils import bytes_ + +if TYPE_CHECKING: # pragma: no cover + from ...core.connection import UpstreamConnectionPool -class HttpWebServerBasePlugin(ABC): + +class HttpWebServerBasePlugin(DescriptorsHandlerMixin, ABC): """Web Server Plugin for routing of requests.""" def __init__( self, - uid: UUID, - flags: Flags, - client: TcpClientConnection, - event_queue: EventQueue): + uid: str, + flags: argparse.Namespace, + client: HttpClientConnection, + event_queue: EventQueue, + upstream_conn_pool: Optional['UpstreamConnectionPool'] = None, + ): self.uid = uid self.flags = flags self.client = client self.event_queue = event_queue + self.upstream_conn_pool = upstream_conn_pool + + @staticmethod + def serve_static_file(path: str, min_compression_length: int) -> memoryview: + try: + with open(path, 'rb') as f: + content = f.read() + content_type = mimetypes.guess_type(path)[0] + if content_type is None: + content_type = 'text/plain' + headers = { + b'Content-Type': bytes_(content_type), + b'Cache-Control': b'max-age=86400', + } + return okResponse( + content=content, + headers=headers, + min_compression_length=min_compression_length, + # TODO: Should we really close or take advantage of keep-alive? + conn_close=True, + ) + except FileNotFoundError: + return NOT_FOUND_RESPONSE_PKT + + def name(self) -> str: + """A unique name for your plugin. + + Defaults to name of the class. This helps plugin developers to directly + access a specific plugin by its name.""" + return self.__class__.__name__ # pragma: no cover @abstractmethod def routes(self) -> List[Tuple[int, str]]: @@ -43,17 +84,101 @@ def handle_request(self, request: HttpParser) -> None: """Handle the request and serve response.""" raise NotImplementedError() # pragma: no cover - @abstractmethod + def on_client_connection_close(self) -> None: + """Client has closed the connection, do any clean up task now.""" + pass + + # No longer abstract since v2.4.0 + # + # @abstractmethod def on_websocket_open(self) -> None: """Called when websocket handshake has finished.""" - raise NotImplementedError() # pragma: no cover + pass # pragma: no cover - @abstractmethod + # No longer abstract since v2.4.0 + # + # @abstractmethod def on_websocket_message(self, frame: WebsocketFrame) -> None: """Handle websocket frame.""" - raise NotImplementedError() # pragma: no cover + return None # pragma: no cover + + # Deprecated since v2.4.0 + # + # Instead use on_client_connection_close. + # + # This callback is no longer invoked. Kindly + # update your plugin before upgrading to v2.4.0. 
+ # + # @abstractmethod + # def on_websocket_close(self) -> None: + # """Called when websocket connection has been closed.""" + # raise NotImplementedError() # pragma: no cover + + def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Use this method to override default access log format (see + DEFAULT_WEB_ACCESS_LOG_FORMAT) or to add/update/modify passed context + for usage by default access logger. + + Return updated log context to use for default logging format, OR + Return None if plugin has logged the request. + """ + return context + + +class ReverseProxyBasePlugin(ABC): + """ReverseProxy base plugin class.""" + + def __init__( + self, + uid: str, + flags: argparse.Namespace, + client: HttpClientConnection, + event_queue: EventQueue, + upstream_conn_pool: Optional['UpstreamConnectionPool'] = None, + ): + self.uid = uid + self.flags = flags + self.client = client + self.event_queue = event_queue + self.upstream_conn_pool = upstream_conn_pool @abstractmethod - def on_websocket_close(self) -> None: - """Called when websocket connection has been closed.""" + def routes(self) -> List[Union[str, Tuple[str, List[bytes]]]]: + """List of routes registered by plugin. + + There are 2 types of routes: + + 1) Dynamic routes (str): Should be a regular expression + 2) Static routes (tuple): Contain 2 elements, a route regular expression + and a list of upstream URLs to serve when the route matches. + + Static routes don't require you to implement the `handle_route` method. + Reverse proxy core will automatically pick one of the configured upstream URLs + and serve it out of the box. + + Dynamic routes are helpful when you want to dynamically match and serve upstream URLs. + To handle dynamic routes, you must implement the `handle_route` method, which + must return the URL to serve.""" raise NotImplementedError() # pragma: no cover + + def before_routing(self, request: HttpParser) -> Optional[HttpParser]: + """Plugins can modify the request, return a response, or close the connection. + + If None is returned, the request will be dropped and closed.""" + return request # pragma: no cover + + def handle_route(self, request: HttpParser, pattern: RePattern) -> Url: + """Implement this method if you have configured dynamic routes.""" + pass + + def regexes(self) -> List[str]: + """Helper method to return list of route regular expressions.""" + routes = [] + for route in self.routes(): + if isinstance(route, str): + routes.append(route) + elif isinstance(route, tuple): + routes.append(route[0]) + else: + raise ValueError('Invalid route type') + return routes diff --git a/proxy/http/server/protocols.py b/proxy/http/server/protocols.py index e2a99ae9e..84b5d8ac7 100644 --- a/proxy/http/server/protocols.py +++ b/proxy/http/server/protocols.py @@ -7,12 +7,21 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. spelling:: + + http + iterable """ from typing import NamedTuple -HttpProtocolTypes = NamedTuple('HttpProtocolTypes', [ - ('HTTP', int), - ('HTTPS', int), - ('WEBSOCKET', int), -]) + +HttpProtocolTypes = NamedTuple( + 'HttpProtocolTypes', [ + ('HTTP', int), + ('HTTPS', int), + ('WEBSOCKET', int), + ], +) + httpProtocolTypes = HttpProtocolTypes(1, 2, 3) diff --git a/proxy/http/server/web.py b/proxy/http/server/web.py index b1b9475e7..ff75d3b32 100644 --- a/proxy/http/server/web.py +++ b/proxy/http/server/web.py @@ -8,229 +8,280 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details. """ -import gzip import re import time -import logging -import os -import mimetypes import socket -from typing import List, Tuple, Optional, Dict, Union, Any, Pattern +import logging +from typing import Any, Dict, List, Tuple, Union, Pattern, Optional from .plugin import HttpWebServerBasePlugin +from ..parser import HttpParser, httpParserTypes +from ..plugin import HttpProtocolHandlerPlugin from .protocols import httpProtocolTypes from ..exception import HttpProtocolException +from ..protocols import httpProtocols +from ..responses import NOT_FOUND_RESPONSE_PKT from ..websocket import WebsocketFrame, websocketOpcodes -from ..codes import httpStatusCodes -from ..parser import HttpParser, httpParserStates, httpParserTypes -from ..handler import HttpProtocolHandlerPlugin +from ...common.flag import flags +from ...common.types import Readables, Writables, Descriptors +from ...common.utils import text_, build_websocket_handshake_response +from ...common.constants import ( + DEFAULT_ENABLE_WEB_SERVER, DEFAULT_STATIC_SERVER_DIR, + DEFAULT_ENABLE_STATIC_SERVER, + DEFAULT_WEB_ACCESS_LOG_FORMAT, DEFAULT_MIN_COMPRESSION_LENGTH, +) -from ...common.utils import bytes_, text_, build_http_response, build_websocket_handshake_response -from ...common.constants import PROXY_AGENT_HEADER_VALUE -from ...common.types import HasFileno logger = logging.getLogger(__name__) -class HttpWebServerPlugin(HttpProtocolHandlerPlugin): - """HttpProtocolHandler plugin which handles incoming requests to local web server.""" +flags.add_argument( + '--enable-web-server', + action='store_true', + default=DEFAULT_ENABLE_WEB_SERVER, + help='Default: False. Whether to enable proxy.HttpWebServerPlugin.', +) - DEFAULT_404_RESPONSE = memoryview(build_http_response( - httpStatusCodes.NOT_FOUND, - reason=b'NOT FOUND', - headers={b'Server': PROXY_AGENT_HEADER_VALUE, - b'Connection': b'close'} - )) +flags.add_argument( + '--enable-static-server', + action='store_true', + default=DEFAULT_ENABLE_STATIC_SERVER, + help='Default: False. Enable inbuilt static file server. ' + 'Optionally, also use --static-server-dir to serve static content ' + 'from custom directory. By default, static file server serves ' + 'out of installed proxy.py python module folder.', +) - DEFAULT_501_RESPONSE = memoryview(build_http_response( - httpStatusCodes.NOT_IMPLEMENTED, - reason=b'NOT IMPLEMENTED', - headers={b'Server': PROXY_AGENT_HEADER_VALUE, - b'Connection': b'close'} - )) +flags.add_argument( + '--static-server-dir', + type=str, + default=DEFAULT_STATIC_SERVER_DIR, + help='Default: "public" folder in directory where proxy.py is placed. ' + 'This option is only applicable when static server is also enabled. ' + 'See --enable-static-server.', +) + +flags.add_argument( + '--min-compression-length', + type=int, + default=DEFAULT_MIN_COMPRESSION_LENGTH, + help='Default: ' + str(DEFAULT_MIN_COMPRESSION_LENGTH) + ' bytes. 
' + + 'Sets the minimum length of a response that will be compressed (gzipped).', +) + + +class HttpWebServerPlugin(HttpProtocolHandlerPlugin): + """HttpProtocolHandler plugin which handles incoming requests to local web server.""" def __init__( self, - *args: Any, **kwargs: Any) -> None: + *args: Any, **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) self.start_time: float = time.time() self.pipeline_request: Optional[HttpParser] = None self.switched_protocol: Optional[int] = None - self.routes: Dict[int, Dict[Pattern[str], HttpWebServerBasePlugin]] = { + self.route: Optional[HttpWebServerBasePlugin] = None + + self.plugins: Dict[str, HttpWebServerBasePlugin] = {} + self.routes: Dict[ + int, Dict[Pattern[str], HttpWebServerBasePlugin], + ] = { httpProtocolTypes.HTTP: {}, httpProtocolTypes.HTTPS: {}, httpProtocolTypes.WEBSOCKET: {}, } - self.route: Optional[HttpWebServerBasePlugin] = None - if b'HttpWebServerBasePlugin' in self.flags.plugins: - for klass in self.flags.plugins[b'HttpWebServerBasePlugin']: - instance = klass( - self.uid, - self.flags, - self.client, - self.event_queue) - for (protocol, route) in instance.routes(): - self.routes[protocol][re.compile(route)] = instance + self._initialize_web_plugins() @staticmethod - def read_and_build_static_file_response(path: str) -> memoryview: - with open(path, 'rb') as f: - content = f.read() - content_type = mimetypes.guess_type(path)[0] - if content_type is None: - content_type = 'text/plain' - return memoryview(build_http_response( - httpStatusCodes.OK, - reason=b'OK', - headers={ - b'Content-Type': bytes_(content_type), - b'Cache-Control': b'max-age=86400', - b'Content-Encoding': b'gzip', - b'Connection': b'close', - }, - body=gzip.compress(content))) - - def serve_file_or_404(self, path: str) -> bool: - """Read and serves a file from disk. - - Queues 404 Not Found for IOError. - Shouldn't this be server error? 
- """ - try: - self.client.queue( - self.read_and_build_static_file_response(path)) - except IOError: - self.client.queue(self.DEFAULT_404_RESPONSE) - return True - - def try_upgrade(self) -> bool: - if self.request.has_header(b'connection') and \ - self.request.header(b'connection').lower() == b'upgrade': - if self.request.has_header(b'upgrade') and \ - self.request.header(b'upgrade').lower() == b'websocket': - self.client.queue( - memoryview(build_websocket_handshake_response( - WebsocketFrame.key_to_accept( - self.request.header(b'Sec-WebSocket-Key'))))) - self.switched_protocol = httpProtocolTypes.WEBSOCKET - else: - self.client.queue(self.DEFAULT_501_RESPONSE) - return True - return False - - def on_request_complete(self) -> Union[socket.socket, bool]: - if self.request.has_upstream_server(): - return False - - assert self.request.path - - # If a websocket route exists for the path, try upgrade - for route in self.routes[httpProtocolTypes.WEBSOCKET]: - match = route.match(text_(self.request.path)) - if match: - self.route = self.routes[httpProtocolTypes.WEBSOCKET][route] + def protocols() -> List[int]: + return [httpProtocols.WEB_SERVER] - # Connection upgrade - teardown = self.try_upgrade() - if teardown: - return True + def _initialize_web_plugins(self) -> None: + for klass in self.flags.plugins[b'HttpWebServerBasePlugin']: + instance: HttpWebServerBasePlugin = klass( + self.uid, + self.flags, + self.client, + self.event_queue, + self.upstream_conn_pool, + ) + self.plugins[instance.name()] = instance + for (protocol, route) in instance.routes(): + pattern = re.compile(route) + self.routes[protocol][pattern] = self.plugins[instance.name()] - # For upgraded connections, nothing more to do - if self.switched_protocol: - # Invoke plugin.on_websocket_open - self.route.on_websocket_open() - return False + def encryption_enabled(self) -> bool: + return self.flags.keyfile is not None and \ + self.flags.certfile is not None - break - - # Routing for Http(s) requests - protocol = httpProtocolTypes.HTTPS \ - if self.flags.encryption_enabled() else \ - httpProtocolTypes.HTTP - for route in self.routes[protocol]: - match = route.match(text_(self.request.path)) - if match: - self.route = self.routes[protocol][route] - self.route.handle_request(self.request) - return False + def switch_to_websocket(self) -> None: + self.client.queue( + memoryview( + build_websocket_handshake_response( + WebsocketFrame.key_to_accept( + self.request.header(b'Sec-WebSocket-Key'), + ), + ), + ), + ) + self.switched_protocol = httpProtocolTypes.WEBSOCKET + def on_request_complete(self) -> Union[socket.socket, bool]: + path = self.request.path or b'/' + teardown = self._try_route(path) + # Try route signaled to teardown + # or if it did find a valid route + if teardown or self.route is not None: + return teardown # No-route found, try static serving if enabled if self.flags.enable_static_server: - path = text_(self.request.path).split('?')[0] - if os.path.isfile(self.flags.static_server_dir + path): - return self.serve_file_or_404( - self.flags.static_server_dir + path) - + self._try_static_or_404(path) + return True # Catch all unhandled web server requests, return 404 - self.client.queue(self.DEFAULT_404_RESPONSE) + self.client.queue(NOT_FOUND_RESPONSE_PKT) return True - def write_to_descriptors(self, w: List[Union[int, HasFileno]]) -> bool: - pass + async def get_descriptors(self) -> Descriptors: + r, w = [], [] + for plugin in self.plugins.values(): + r1, w1 = await plugin.get_descriptors() + r.extend(r1) + 
w.extend(w1) + return r, w + + async def write_to_descriptors(self, w: Writables) -> bool: + for plugin in self.plugins.values(): + teardown = await plugin.write_to_descriptors(w) + if teardown: + return True + return False - def read_from_descriptors(self, r: List[Union[int, HasFileno]]) -> bool: - pass + async def read_from_descriptors(self, r: Readables) -> bool: + for plugin in self.plugins.values(): + teardown = await plugin.read_from_descriptors(r) + if teardown: + return True + return False - def on_client_data(self, raw: memoryview) -> Optional[memoryview]: + def on_client_data(self, raw: memoryview) -> None: if self.switched_protocol == httpProtocolTypes.WEBSOCKET: - # TODO(abhinavsingh): Remove .tobytes after websocket frame parser - # is memoryview compliant + # TODO(abhinavsingh): Do we really tobytes() here? + # Websocket parser currently doesn't depend on internal + # buffers, due to which it can directly parse out of + # memory views. But how about large payloads scenarios? remaining = raw.tobytes() frame = WebsocketFrame() while remaining != b'': - # TODO: Teardown if invalid protocol exception + # TODO: Tear down if invalid protocol exception remaining = frame.parse(remaining) if frame.opcode == websocketOpcodes.CONNECTION_CLOSE: - logger.warning( - 'Client sent connection close packet') - raise HttpProtocolException() + raise HttpProtocolException( + 'Client sent connection close packet', + ) else: assert self.route self.route.on_websocket_message(frame) frame.reset() - return None + return # If 1st valid request was completed and it's a HTTP/1.1 keep-alive # And only if we have a route, parse pipeline requests - elif self.request.state == httpParserStates.COMPLETE and \ - self.request.is_http_1_1_keep_alive() and \ + if self.request.is_complete and \ + self.request.is_http_1_1_keep_alive and \ self.route is not None: if self.pipeline_request is None: self.pipeline_request = HttpParser( - httpParserTypes.REQUEST_PARSER) - # TODO(abhinavsingh): Remove .tobytes after parser is memoryview - # compliant - self.pipeline_request.parse(raw.tobytes()) - if self.pipeline_request.state == httpParserStates.COMPLETE: + httpParserTypes.REQUEST_PARSER, + ) + self.pipeline_request.parse(raw) + if self.pipeline_request.is_complete: self.route.handle_request(self.pipeline_request) - if not self.pipeline_request.is_http_1_1_keep_alive(): - logger.error( - 'Pipelined request is not keep-alive, will teardown request...') - raise HttpProtocolException() + if not self.pipeline_request.is_http_1_1_keep_alive: + raise HttpProtocolException( + 'Pipelined request is not keep-alive, will tear down request...', + ) self.pipeline_request = None - return raw def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]: return chunk def on_client_connection_close(self) -> None: - if self.request.has_upstream_server(): - return - if self.switched_protocol: - # Invoke plugin.on_websocket_close - assert self.route - self.route.on_websocket_close() - self.access_log() - - def access_log(self) -> None: - logger.info( - '%s:%s - %s %s - %.2f ms' % - (self.client.addr[0], - self.client.addr[1], - text_(self.request.method), - text_(self.request.path), - (time.time() - self.start_time) * 1000)) - - def get_descriptors( - self) -> Tuple[List[socket.socket], List[socket.socket]]: - return [], [] + context = { + 'client_ip': None if not self.client.addr else self.client.addr[0], + 'client_port': None if not self.client.addr else self.client.addr[1], + 'connection_time_ms': '%.2f' % ((time.time() - 
self.start_time) * 1000), + # Request + 'request_method': text_(self.request.method), + 'request_path': text_(self.request.path), + 'request_bytes': self.request.total_size, + 'request_ua': text_(self.request.header(b'user-agent')) + if self.request.has_header(b'user-agent') + else None, + 'request_version': None if not self.request.version else text_(self.request.version), + # Response + # + # TODO: Track and inject web server specific response attributes + # Currently, plugins are allowed to queue raw bytes, because of + # which we'll have to reparse the queued packets to deduce + # several attributes required below. At least for code and + # reason attributes. + # + # 'response_bytes': self.response.total_size, + # 'response_code': text_(self.response.code), + # 'response_reason': text_(self.response.reason), + } log_handled = False if self.route: + # Maybe merge on_client_connection_close and on_access_log? + # Probably by simply deprecating on_client_connection_close in the future. + self.route.on_client_connection_close() + ctx = self.route.on_access_log(context) + if ctx is None: + log_handled = True + else: + context = ctx + if not log_handled: + self.access_log(context) + + def access_log(self, context: Dict[str, Any]) -> None: + logger.info(DEFAULT_WEB_ACCESS_LOG_FORMAT.format_map(context)) + + @property + def _protocol(self) -> Tuple[bool, int]: + do_ws_upgrade = self.request.is_connection_upgrade and \ + self.request.header(b'upgrade').lower() == b'websocket' + return do_ws_upgrade, httpProtocolTypes.WEBSOCKET \ + if do_ws_upgrade \ + else httpProtocolTypes.HTTPS \ + if self.encryption_enabled() \ + else httpProtocolTypes.HTTP + + def _try_route(self, path: bytes) -> bool: + do_ws_upgrade, protocol = self._protocol + for route in self.routes[protocol]: + if route.match(text_(path)): + self.route = self.routes[protocol][route] + assert self.route + # Optionally, upgrade protocol + if do_ws_upgrade: + self.switch_to_websocket() + assert self.route + # Invoke plugin.on_websocket_open + self.route.on_websocket_open() + else: + # Invoke plugin.handle_request + self.route.handle_request(self.request) + if self.request.has_header(b'connection') and \ + self.request.header(b'connection').lower() == b'close': + return True + return False + + def _try_static_or_404(self, path: bytes) -> None: + path = text_(path).split('?', 1)[0] + self.client.queue( + HttpWebServerBasePlugin.serve_static_file( + self.flags.static_server_dir + path, + self.flags.min_compression_length, + ), + ) diff --git a/proxy/http/url.py b/proxy/http/url.py new file mode 100644 index 000000000..c799efaa1 --- /dev/null +++ b/proxy/http/url.py @@ -0,0 +1,159 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + .. spelling:: + + http + url +""" +from typing import List, Tuple, Optional + +from .exception import HttpProtocolException +from ..common.utils import text_ +from ..common.constants import AT, COLON, SLASH, DEFAULT_ALLOWED_URL_SCHEMES + + +class Url: + """``urllib.urlparse`` doesn't work for proxy.py, so we wrote a simple URL. + + Currently, URL only implements what is necessary for HttpParser to work.
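+ + A few illustrative parses, following the ``from_bytes`` implementation below: + + Url.from_bytes(b'/get?a=b') # remainder=b'/get?a=b' + Url.from_bytes(b'httpbin.org:443') # hostname=b'httpbin.org', port=443 + Url.from_bytes(b'http://httpbin.org/get') # scheme=b'http', hostname=b'httpbin.org', remainder=b'/get'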
+ """ + + def __init__( + self, + scheme: Optional[bytes] = None, + username: Optional[bytes] = None, + password: Optional[bytes] = None, + hostname: Optional[bytes] = None, + port: Optional[int] = None, + remainder: Optional[bytes] = None, + ) -> None: + self.scheme: Optional[bytes] = scheme + self.username: Optional[bytes] = username + self.password: Optional[bytes] = password + self.hostname: Optional[bytes] = hostname + self.port: Optional[int] = port + self.remainder: Optional[bytes] = remainder + + @property + def has_credentials(self) -> bool: + """Returns true if both username and password components are present.""" + return self.username is not None and self.password is not None + + def __str__(self) -> str: + url = '' + if self.scheme: + url += '{0}://'.format(text_(self.scheme)) + if self.hostname: + url += text_(self.hostname) + if self.port: + url += ':{0}'.format(self.port) + if self.remainder: + url += text_(self.remainder) + return url + + @classmethod + def from_bytes(cls, raw: bytes, allowed_url_schemes: Optional[List[bytes]] = None) -> 'Url': + """A URL within proxy.py core can have several styles, + because proxy.py supports both proxy and web server use cases. + + Example: + For a Web server, url is like ``/`` or ``/get`` or ``/get?key=value`` + For a HTTPS connect tunnel, url is like ``httpbin.org:443`` + For a HTTP proxy request, url is like ``http://httpbin.org/get`` + + proxy.py internally never expects a https scheme in the request line. + But `Url` class provides support for parsing any scheme present in the URLs. + e.g. ftp, icap etc. + + If a url with no scheme is parsed, e.g. ``//host/abc.js``, then scheme + defaults to `http`. + + Further: + 1) URL may contain unicode characters + 2) URL may contain IPv4 and IPv6 format addresses instead of domain names + """ + # SLASH == 47, check if URL starts with single slash but not double slash + starts_with_single_slash = raw[0] == 47 + starts_with_double_slash = starts_with_single_slash and \ + len(raw) >= 2 and \ + raw[1] == 47 + if starts_with_single_slash and \ + not starts_with_double_slash: + return cls(remainder=raw) + scheme = None + rest = None + if not starts_with_double_slash: + # Find scheme + parts = raw.split(b'://', 1) + if len(parts) == 2: + scheme = parts[0] + rest = parts[1] + if scheme not in (allowed_url_schemes or DEFAULT_ALLOWED_URL_SCHEMES): + raise HttpProtocolException( + 'Invalid scheme received in the request line %r' % raw, + ) + else: + rest = raw[len(SLASH + SLASH):] + if scheme is not None or starts_with_double_slash: + assert rest is not None + parts = rest.split(SLASH, 1) + username, password, host, port = Url._parse(parts[0]) + return cls( + scheme=scheme if not starts_with_double_slash else b'http', + username=username, + password=password, + hostname=host, + port=port, + remainder=None if len(parts) == 1 else ( + SLASH + parts[1] + ), + ) + username, password, host, port = Url._parse(raw) + return cls(username=username, password=password, hostname=host, port=port) + + @staticmethod + def _parse(raw: bytes) -> Tuple[ + Optional[bytes], + Optional[bytes], + bytes, + Optional[int], + ]: + split_at = raw.split(AT, 1) + username, password = None, None + if len(split_at) == 2: + username, password = split_at[0].split(COLON) + parts = split_at[-1].split(COLON, 2) + num_parts = len(parts) + port: Optional[int] = None + # No port found + if num_parts == 1: + return username, password, parts[0], None + # Host and port found + if num_parts == 2: + return username, password, 
COLON.join(parts[:-1]), int(parts[-1]) + # More than a single COLON i.e. IPv6 scenario + try: + # Try to resolve last part as an int port + last_token = parts[-1].split(COLON) + port = int(last_token[-1]) + host = COLON.join(parts[:-1]) + COLON + \ + COLON.join(last_token[:-1]) + except ValueError: + # If unable to convert last part into port, + # treat entire data as host + host, port = raw, None + # patch up invalid ipv6 scenario + rhost = host.decode('utf-8') + if COLON.decode('utf-8') in rhost and \ + rhost[0] != '[' and \ + rhost[-1] != ']': + host = b'[' + host + b']' + return username, password, host, port diff --git a/proxy/http/websocket/__init__.py b/proxy/http/websocket/__init__.py new file mode 100644 index 000000000..a787535ca --- /dev/null +++ b/proxy/http/websocket/__init__.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. + + .. spelling:: + + http + Submodules + websocket + Websocket +""" +from .frame import WebsocketFrame, websocketOpcodes +from .client import WebsocketClient +from .plugin import WebSocketTransportBasePlugin + + +__all__ = [ + 'websocketOpcodes', + 'WebsocketFrame', + 'WebsocketClient', + 'WebSocketTransportBasePlugin', +] diff --git a/proxy/http/websocket/client.py b/proxy/http/websocket/client.py new file mode 100644 index 000000000..2f61cbab8 --- /dev/null +++ b/proxy/http/websocket/client.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import base64 +import socket +import secrets +import selectors +from typing import Callable, Optional + +from .frame import WebsocketFrame +from ..parser import HttpParser, httpParserTypes +from ...common.types import TcpOrTlsSocket +from ...common.utils import ( + text_, new_socket_connection, build_websocket_handshake_request, +) +from ...core.connection import TcpConnection, tcpConnectionTypes +from ...common.constants import ( + DEFAULT_BUFFER_SIZE, DEFAULT_SELECTOR_SELECT_TIMEOUT, +) + + +class WebsocketClient(TcpConnection): + """Websocket client connection. + + TODO: Make me compatible with the work framework.""" + + def __init__( + self, + hostname: bytes, + port: int, + path: bytes = b'/', + on_message: Optional[Callable[[WebsocketFrame], None]] = None, + ) -> None: + super().__init__(tcpConnectionTypes.CLIENT) + self.hostname: bytes = hostname + self.port: int = port + self.path: bytes = path + self.sock: socket.socket = new_socket_connection( + (socket.gethostbyname(text_(self.hostname)), self.port), + ) + self.on_message: Optional[ + Callable[ + [ + WebsocketFrame, + ], + None, + ] + ] = on_message + self.selector: selectors.DefaultSelector = selectors.DefaultSelector() + + @property + def connection(self) -> TcpOrTlsSocket: + return self.sock + + def handshake(self) -> None: + """Start websocket upgrade & handshake protocol""" + self.upgrade() + self.sock.setblocking(False) + + def upgrade(self) -> None: + """Creates a key and sends websocket handshake packet to upstream. 
+ Receives response from the server and asserts that websocket + accept header is valid in the response.""" + key = base64.b64encode(secrets.token_bytes(16)) + self.sock.send( + build_websocket_handshake_request( + key, + url=self.path, + host=self.hostname, + ), + ) + response = HttpParser(httpParserTypes.RESPONSE_PARSER) + response.parse(memoryview(self.sock.recv(DEFAULT_BUFFER_SIZE))) + accept = response.header(b'Sec-Websocket-Accept') + assert WebsocketFrame.key_to_accept(key) == accept + + def shutdown(self, _data: Optional[bytes] = None) -> None: + """Closes connection with the server.""" + super().close() + + def run_once(self) -> bool: + ev = selectors.EVENT_READ + if self.has_buffer(): + ev |= selectors.EVENT_WRITE + self.selector.register(self.sock.fileno(), ev) + events = self.selector.select(timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT) + self.selector.unregister(self.sock) + for _, mask in events: + if mask & selectors.EVENT_READ and self.on_message: + # TODO: client recvbuf size flag currently not used here + raw = self.recv() + if raw is None or raw == b'': + self.closed = True + return True + frame = WebsocketFrame() + frame.parse(raw.tobytes()) + self.on_message(frame) + elif mask & selectors.EVENT_WRITE: + # TODO: max sendbuf size flag currently not used here + self.flush() + return False + + def run(self) -> None: + try: + while not self.closed: + if self.run_once(): + break + except KeyboardInterrupt: + pass + finally: + if not self.closed: + self.selector.unregister(self.sock) + try: + self.sock.shutdown(socket.SHUT_WR) + except OSError: + pass + self.sock.close() + self.selector.close() diff --git a/proxy/http/websocket.py b/proxy/http/websocket/frame.py similarity index 52% rename from proxy/http/websocket.py rename to proxy/http/websocket/frame.py index a6eb5a337..af6e0c7e9 100644 --- a/proxy/http/websocket.py +++ b/proxy/http/websocket/frame.py @@ -7,35 +7,33 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. 
spelling:: + + http + iterable + websocket + Websocket """ -import hashlib +import io import base64 -import selectors import struct -import socket -import secrets -import ssl -import ipaddress import logging -import io - -from typing import TypeVar, Type, Optional, NamedTuple, Union, Callable - -from .parser import httpParserTypes, HttpParser - -from ..common.constants import DEFAULT_BUFFER_SIZE -from ..common.utils import new_socket_connection, build_websocket_handshake_request -from ..core.connection import tcpConnectionTypes, TcpConnection - - -WebsocketOpcodes = NamedTuple('WebsocketOpcodes', [ - ('CONTINUATION_FRAME', int), - ('TEXT_FRAME', int), - ('BINARY_FRAME', int), - ('CONNECTION_CLOSE', int), - ('PING', int), - ('PONG', int), -]) +import secrets +from typing import Type, TypeVar, Optional, NamedTuple + + +WebsocketOpcodes = NamedTuple( + 'WebsocketOpcodes', [ + ('CONTINUATION_FRAME', int), + ('TEXT_FRAME', int), + ('BINARY_FRAME', int), + ('CONNECTION_CLOSE', int), + ('PING', int), + ('PONG', int), + ], +) websocketOpcodes = WebsocketOpcodes(0x0, 0x1, 0x2, 0x8, 0x9, 0xA) @@ -101,35 +99,38 @@ def build(self) -> bytes: (1 << 6 if self.rsv1 else 0) | (1 << 5 if self.rsv2 else 0) | (1 << 4 if self.rsv3 else 0) | - self.opcode - )) + self.opcode, + ), + ) assert self.payload_length is not None if self.payload_length < 126: raw.write( struct.pack( '!B', - (1 << 7 if self.masked else 0) | self.payload_length - ) + (1 << 7 if self.masked else 0) | self.payload_length, + ), ) elif self.payload_length < 1 << 16: raw.write( struct.pack( '!BH', (1 << 7 if self.masked else 0) | 126, - self.payload_length - ) + self.payload_length, + ), ) elif self.payload_length < 1 << 64: raw.write( struct.pack( '!BQ', (1 << 7 if self.masked else 0) | 127, - self.payload_length - ) + self.payload_length, + ), ) else: - raise ValueError(f'Invalid payload_length { self.payload_length },' - f'maximum allowed { 1 << 64 }') + raise ValueError( + f'Invalid payload_length { self.payload_length },' + f'maximum allowed { 1 << 64 }', + ) if self.masked and self.data: mask = secrets.token_bytes(4) if self.mask is None else self.mask raw.write(mask) @@ -177,92 +178,6 @@ def apply_mask(data: bytes, mask: bytes) -> bytes: @staticmethod def key_to_accept(key: bytes) -> bytes: - sha1 = hashlib.sha1() + sha1 = hashlib.sha1() # noqa: S324 sha1.update(key + WebsocketFrame.GUID) return base64.b64encode(sha1.digest()) - - -class WebsocketClient(TcpConnection): - - def __init__(self, - hostname: Union[ipaddress.IPv4Address, ipaddress.IPv6Address], - port: int, - path: bytes = b'/', - on_message: Optional[Callable[[WebsocketFrame], None]] = None) -> None: - super().__init__(tcpConnectionTypes.CLIENT) - self.hostname: Union[ipaddress.IPv4Address, - ipaddress.IPv6Address] = hostname - self.port: int = port - self.path: bytes = path - self.sock: socket.socket = new_socket_connection( - (str(self.hostname), self.port)) - self.on_message: Optional[Callable[[ - WebsocketFrame], None]] = on_message - self.upgrade() - self.sock.setblocking(False) - self.selector: selectors.DefaultSelector = selectors.DefaultSelector() - - @property - def connection(self) -> Union[ssl.SSLSocket, socket.socket]: - return self.sock - - def upgrade(self) -> None: - key = base64.b64encode(secrets.token_bytes(16)) - self.sock.send(build_websocket_handshake_request(key, url=self.path)) - response = HttpParser(httpParserTypes.RESPONSE_PARSER) - response.parse(self.sock.recv(DEFAULT_BUFFER_SIZE)) - accept =
response.header(b'Sec-Websocket-Accept') - assert WebsocketFrame.key_to_accept(key) == accept - - def ping(self, data: Optional[bytes] = None) -> None: - pass - - def pong(self, data: Optional[bytes] = None) -> None: - pass - - def shutdown(self, _data: Optional[bytes] = None) -> None: - """Closes connection with the server.""" - super().close() - - def run_once(self) -> bool: - ev = selectors.EVENT_READ - if self.has_buffer(): - ev |= selectors.EVENT_WRITE - self.selector.register(self.sock.fileno(), ev) - events = self.selector.select(timeout=1) - self.selector.unregister(self.sock) - for _, mask in events: - if mask & selectors.EVENT_READ and self.on_message: - raw = self.recv() - if raw is None or raw.tobytes() == b'': - self.closed = True - logger.debug('Websocket connection closed by server') - return True - frame = WebsocketFrame() - # TODO(abhinavsingh): Remove .tobytes after parser is - # memoryview compliant - frame.parse(raw.tobytes()) - self.on_message(frame) - elif mask & selectors.EVENT_WRITE: - logger.debug(self.buffer) - self.flush() - return False - - def run(self) -> None: - logger.debug('running') - try: - while not self.closed: - teardown = self.run_once() - if teardown: - break - except KeyboardInterrupt: - pass - finally: - try: - self.selector.unregister(self.sock) - self.sock.shutdown(socket.SHUT_WR) - except Exception as e: - logging.exception( - 'Exception while shutdown of websocket client', exc_info=e) - self.sock.close() - logger.info('done') diff --git a/proxy/http/websocket/plugin.py b/proxy/http/websocket/plugin.py new file mode 100644 index 000000000..be544773a --- /dev/null +++ b/proxy/http/websocket/plugin.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import json +import argparse +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List + +from . 
import WebsocketFrame +from ...core.event import EventQueue +from ...common.utils import bytes_ + + +if TYPE_CHECKING: # pragma: no cover + from ..connection import HttpClientConnection + + +class WebSocketTransportBasePlugin(ABC): + """Abstract class for plugins extending dashboard websocket API.""" + + def __init__( + self, + flags: argparse.Namespace, + client: 'HttpClientConnection', + event_queue: EventQueue, + ) -> None: + self.flags = flags + self.client = client + self.event_queue = event_queue + + @abstractmethod + def methods(self) -> List[str]: + """Return list of methods that this plugin will handle.""" + pass # pragma: no cover + + def connected(self) -> None: + """Invoked when client websocket handshake finishes.""" + pass # pragma: no cover + + @abstractmethod + def handle_message(self, message: Dict[str, Any]) -> None: + """Handle messages for registered methods.""" + pass # pragma: no cover + + def disconnected(self) -> None: + """Invoked when client websocket connection gets closed.""" + pass # pragma: no cover + + def reply(self, data: Dict[str, Any]) -> None: + self.client.queue( + memoryview( + WebsocketFrame.text( + bytes_( + json.dumps(data), + ), + ), + ), + ) diff --git a/proxy/http/websocket/transport.py b/proxy/http/websocket/transport.py new file mode 100644 index 000000000..463195ef5 --- /dev/null +++ b/proxy/http/websocket/transport.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import json +import logging +from typing import Any, Dict, List, Tuple + +from .frame import WebsocketFrame +from .plugin import WebSocketTransportBasePlugin +from ..parser import HttpParser +from ..server import HttpWebServerBasePlugin, httpProtocolTypes +from ...common.utils import bytes_ + + +logger = logging.getLogger(__name__) + + +class WebSocketTransport(HttpWebServerBasePlugin): + """WebSocket transport framework.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.plugins: List[WebSocketTransportBasePlugin] = [] + # Registered methods and handler plugin + self.methods: Dict[str, WebSocketTransportBasePlugin] = {} + if b'WebSocketTransportBasePlugin' in self.flags.plugins: + for klass in self.flags.plugins[b'WebSocketTransportBasePlugin']: + p = klass(self.flags, self.client, self.event_queue) + self.plugins.append(p) + for method in p.methods(): + self.methods[method] = p + + def routes(self) -> List[Tuple[int, str]]: + return [ + (httpProtocolTypes.WEBSOCKET, r'/transport/$'), + ] + + def handle_request(self, request: HttpParser) -> None: + raise NotImplementedError() + + def on_websocket_open(self) -> None: + for plugin in self.plugins: + plugin.connected() + + def on_websocket_message(self, frame: WebsocketFrame) -> None: + try: + assert frame.data + message = json.loads(frame.data) + except UnicodeDecodeError: + logger.error(frame.data) + logger.info(frame.opcode) + return + + method = message['method'] + if method == 'ping': + self.reply({'id': message['id'], 'response': 'pong'}) + elif method in self.methods: + self.methods[method].handle_message(message) + else: + logger.info(frame.data) + logger.info(frame.opcode) + self.reply({'id': message['id'], 'response': 'not_implemented'}) + + def 
on_client_connection_close(self) -> None: + for plugin in self.plugins: + plugin.disconnected() + + def reply(self, data: Dict[str, Any]) -> None: + self.client.queue( + memoryview( + WebsocketFrame.text( + bytes_( + json.dumps(data), + ), + ), + ), + ) diff --git a/proxy/mempool/neon_tx_send_iterative_strategy.py b/proxy/mempool/neon_tx_send_iterative_strategy.py index a8302346a..d20af3c40 100644 --- a/proxy/mempool/neon_tx_send_iterative_strategy.py +++ b/proxy/mempool/neon_tx_send_iterative_strategy.py @@ -131,7 +131,7 @@ def execute(self) -> NeonTxResultInfo: emulated_step_cnt = max(self._ctx.emulated_evm_step_cnt, self._start_evm_step_cnt) additional_iter_cnt = self._ctx.neon_tx_exec_cfg.resize_iter_cnt - additional_iter_cnt += 5 # begin + finalization + additional_iter_cnt += 2 # begin + finalization tx_list = self.build_tx_list(emulated_step_cnt, additional_iter_cnt) tx_sender = IterativeNeonTxSender(self, self._ctx.solana, self._ctx.signer) tx_sender.send(tx_list) diff --git a/proxy/plugin/__init__.py b/proxy/plugin/__init__.py index cc1ee76c7..0c7724596 100644 --- a/proxy/plugin/__init__.py +++ b/proxy/plugin/__init__.py @@ -7,29 +7,19 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. + + .. spelling:: + + Cloudflare + ws + onmessage + httpbin + localhost + Lua """ -from .cache import CacheResponsesPlugin, BaseCacheResponsesPlugin -from .filter_by_upstream import FilterByUpstreamHostPlugin -from .man_in_the_middle import ManInTheMiddlePlugin -from .mock_rest_api import ProposedRestApiPlugin -from .modify_post_data import ModifyPostDataPlugin -from .redirect_to_custom_server import RedirectToCustomServerPlugin -from .shortlink import ShortLinkPlugin -from .web_server_route import WebServerPlugin -from .reverse_proxy import ReverseProxyPlugin -from .proxy_pool import ProxyPoolPlugin from .neon_rpc_api_plugin import NeonRpcApiPlugin + __all__ = [ - 'CacheResponsesPlugin', - 'BaseCacheResponsesPlugin', - 'FilterByUpstreamHostPlugin', - 'ManInTheMiddlePlugin', - 'ProposedRestApiPlugin', - 'ModifyPostDataPlugin', - 'RedirectToCustomServerPlugin', - 'ShortLinkPlugin', - 'WebServerPlugin', - 'ReverseProxyPlugin', - 'ProxyPoolPlugin', + 'NeonRpcApiPlugin', ] diff --git a/proxy/plugin/cache/base.py b/proxy/plugin/cache/base.py deleted file mode 100644 index 81a2ef65f..000000000 --- a/proxy/plugin/cache/base.py +++ /dev/null @@ -1,60 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. - :license: BSD, see LICENSE for more details. -""" -import logging -from typing import Optional, Any - -from ...http.parser import HttpParser -from ...http.proxy import HttpProxyBasePlugin -from .store.base import CacheStore - -logger = logging.getLogger(__name__) - - -class BaseCacheResponsesPlugin(HttpProxyBasePlugin): - """Base cache plugin. - - It requires a storage backend to work with. Storage class - must implement CacheStore interface. - - Different storage backends can be used per request if required. 
- """ - - def __init__( - self, - *args: Any, - **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.store: Optional[CacheStore] = None - - def set_store(self, store: CacheStore) -> None: - self.store = store - - def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: - assert self.store - try: - self.store.open(request) - except Exception as e: - logger.info('Caching disabled due to exception message %s', str(e)) - return request - - def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: - assert self.store - return self.store.cache_request(request) - - def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: - assert self.store - return self.store.cache_response_chunk(chunk) - - def on_upstream_connection_close(self) -> None: - assert self.store - self.store.close() diff --git a/proxy/plugin/cache/cache_responses.py b/proxy/plugin/cache/cache_responses.py deleted file mode 100644 index 91f290790..000000000 --- a/proxy/plugin/cache/cache_responses.py +++ /dev/null @@ -1,29 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. - :license: BSD, see LICENSE for more details. -""" -import multiprocessing -import tempfile -from typing import Any - -from .store.disk import OnDiskCacheStore -from .base import BaseCacheResponsesPlugin - - -class CacheResponsesPlugin(BaseCacheResponsesPlugin): - """Caches response using OnDiskCacheStore.""" - - # Dynamically enable / disable cache - ENABLED = multiprocessing.Event() - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.disk_store = OnDiskCacheStore( - uid=self.uid, cache_dir=tempfile.gettempdir()) - self.set_store(self.disk_store) diff --git a/proxy/plugin/cache/store/base.py b/proxy/plugin/cache/store/base.py deleted file mode 100644 index eafeaa3c4..000000000 --- a/proxy/plugin/cache/store/base.py +++ /dev/null @@ -1,36 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. - :license: BSD, see LICENSE for more details. -""" -from abc import ABC, abstractmethod -from typing import Optional -from uuid import UUID -from ....http.parser import HttpParser - - -class CacheStore(ABC): - - def __init__(self, uid: UUID) -> None: - self.uid = uid - - @abstractmethod - def open(self, request: HttpParser) -> None: - pass - - @abstractmethod - def cache_request(self, request: HttpParser) -> Optional[HttpParser]: - return request - - @abstractmethod - def cache_response_chunk(self, chunk: memoryview) -> memoryview: - return chunk - - @abstractmethod - def close(self) -> None: - pass diff --git a/proxy/plugin/cache/store/disk.py b/proxy/plugin/cache/store/disk.py deleted file mode 100644 index 91eb4f295..000000000 --- a/proxy/plugin/cache/store/disk.py +++ /dev/null @@ -1,49 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. 
- :license: BSD, see LICENSE for more details. -""" -import logging -import os -from typing import Optional, BinaryIO -from uuid import UUID - -from ....common.utils import text_ -from ....http.parser import HttpParser - -from .base import CacheStore - -logger = logging.getLogger(__name__) - - -class OnDiskCacheStore(CacheStore): - - def __init__(self, uid: UUID, cache_dir: str) -> None: - super().__init__(uid) - self.cache_dir = cache_dir - self.cache_file_path: Optional[str] = None - self.cache_file: Optional[BinaryIO] = None - - def open(self, request: HttpParser) -> None: - self.cache_file_path = os.path.join( - self.cache_dir, - '%s-%s.txt' % (text_(request.host), self.uid.hex)) - self.cache_file = open(self.cache_file_path, "wb") - - def cache_request(self, request: HttpParser) -> Optional[HttpParser]: - return request - - def cache_response_chunk(self, chunk: memoryview) -> memoryview: - if self.cache_file: - self.cache_file.write(chunk.tobytes()) - return chunk - - def close(self) -> None: - if self.cache_file: - self.cache_file.close() - logger.info('Cached response at %s', self.cache_file_path) diff --git a/proxy/plugin/filter_by_upstream.py b/proxy/plugin/filter_by_upstream.py deleted file mode 100644 index a919bd15b..000000000 --- a/proxy/plugin/filter_by_upstream.py +++ /dev/null @@ -1,43 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. - :license: BSD, see LICENSE for more details. -""" -from typing import Optional - -from ..http.exception import HttpRequestRejected -from ..http.parser import HttpParser -from ..http.codes import httpStatusCodes -from ..http.proxy import HttpProxyBasePlugin - - -class FilterByUpstreamHostPlugin(HttpProxyBasePlugin): - """Drop traffic by inspecting upstream host.""" - - FILTERED_DOMAINS = [b'google.com', b'www.google.com'] - - def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: - if request.host in self.FILTERED_DOMAINS: - raise HttpRequestRejected( - status_code=httpStatusCodes.I_AM_A_TEAPOT, reason=b'I\'m a tea pot', - headers={ - b'Connection': b'close', - } - ) - return request - - def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: - return request - - def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: - return chunk - - def on_upstream_connection_close(self) -> None: - pass diff --git a/proxy/plugin/man_in_the_middle.py b/proxy/plugin/man_in_the_middle.py deleted file mode 100644 index cc3ab63e7..000000000 --- a/proxy/plugin/man_in_the_middle.py +++ /dev/null @@ -1,36 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. - :license: BSD, see LICENSE for more details. 
-""" -from typing import Optional - -from ..common.utils import build_http_response -from ..http.parser import HttpParser -from ..http.codes import httpStatusCodes -from ..http.proxy import HttpProxyBasePlugin - - -class ManInTheMiddlePlugin(HttpProxyBasePlugin): - """Modifies upstream server responses.""" - - def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: - return request - - def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: - return request - - def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: - return memoryview(build_http_response( - httpStatusCodes.OK, - reason=b'OK', body=b'Hello from man in the middle')) - - def on_upstream_connection_close(self) -> None: - pass diff --git a/proxy/plugin/mock_rest_api.py b/proxy/plugin/mock_rest_api.py deleted file mode 100644 index 270c86440..000000000 --- a/proxy/plugin/mock_rest_api.py +++ /dev/null @@ -1,88 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. - :license: BSD, see LICENSE for more details. -""" -import json -from typing import Optional - -from ..common.utils import bytes_, build_http_response, text_ -from ..http.parser import HttpParser -from ..http.proxy import HttpProxyBasePlugin -from ..http.codes import httpStatusCodes - - -class ProposedRestApiPlugin(HttpProxyBasePlugin): - """Mock responses for your upstream REST API. - - Used to test and develop client side applications - without need of an actual upstream REST API server. - - Returns proposed REST API mock responses to the client - without establishing upstream connection. - - Note: This plugin won't work if your client is making - HTTPS connection to api.example.com. 
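That limitation is inherent: without TLS interception the proxy only sees an opaque CONNECT tunnel, so the plaintext Host never reaches the plugin. A quick way to exercise the mock over plain HTTP; the port and URL below are illustrative, assuming a local instance listening on 8899:

import requests

# Route plain-HTTP traffic through a local proxy.py instance so the plugin can
# intercept api.example.com before any real DNS lookup or upstream connection.
response = requests.get(
    'http://api.example.com/v1/users/',
    proxies={'http': 'http://localhost:8899'},
)
print(response.json())  # served from REST_API_SPEC; no upstream was contacted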
- """ - - API_SERVER = b'api.example.com' - - REST_API_SPEC = { - b'/v1/users/': { - 'count': 2, - 'next': None, - 'previous': None, - 'results': [ - { - 'email': 'you@example.com', - 'groups': [], - 'url': text_(API_SERVER) + '/v1/users/1/', - 'username': 'admin', - }, - { - 'email': 'someone@example.com', - 'groups': [], - 'url': text_(API_SERVER) + '/v1/users/2/', - 'username': 'someone', - }, - ] - }, - } - - def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: - # Return None to disable establishing connection to upstream - # Most likely our api.example.com won't even exist under development - # scenario - return None - - def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: - if request.host != self.API_SERVER: - return request - assert request.path - if request.path in self.REST_API_SPEC: - self.client.queue(memoryview(build_http_response( - httpStatusCodes.OK, - reason=b'OK', - headers={b'Content-Type': b'application/json'}, - body=bytes_(json.dumps( - self.REST_API_SPEC[request.path])) - ))) - else: - self.client.queue(memoryview(build_http_response( - httpStatusCodes.NOT_FOUND, - reason=b'NOT FOUND', body=b'Not Found' - ))) - return None - - def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: - return chunk - - def on_upstream_connection_close(self) -> None: - pass diff --git a/proxy/plugin/modify_post_data.py b/proxy/plugin/modify_post_data.py deleted file mode 100644 index 98b89daf5..000000000 --- a/proxy/plugin/modify_post_data.py +++ /dev/null @@ -1,47 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. - :license: BSD, see LICENSE for more details. -""" -from typing import Optional - -from ..common.utils import bytes_ -from ..http.parser import HttpParser -from ..http.proxy import HttpProxyBasePlugin -from ..http.methods import httpMethods - - -class ModifyPostDataPlugin(HttpProxyBasePlugin): - """Modify POST request body before sending to upstream server.""" - - MODIFIED_BODY = b'{"key": "modified"}' - - def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: - return request - - def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: - if request.method == httpMethods.POST: - request.body = ModifyPostDataPlugin.MODIFIED_BODY - # Update Content-Length header only when request is NOT chunked - # encoded - if not request.is_chunked_encoded(): - request.add_header(b'Content-Length', - bytes_(len(request.body))) - # Enforce content-type json - if request.has_header(b'Content-Type'): - request.del_header(b'Content-Type') - request.add_header(b'Content-Type', b'application/json') - return request - - def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: - return chunk - - def on_upstream_connection_close(self) -> None: - pass diff --git a/proxy/plugin/proxy_pool.py b/proxy/plugin/proxy_pool.py deleted file mode 100644 index 3cb664c70..000000000 --- a/proxy/plugin/proxy_pool.py +++ /dev/null @@ -1,84 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. 
- :license: BSD, see LICENSE for more details. -""" -import random -import socket -from typing import Optional, Any - -from ..common.constants import DEFAULT_BUFFER_SIZE, SLASH, COLON -from ..common.utils import new_socket_connection -from ..http.proxy import HttpProxyBasePlugin -from ..http.parser import HttpParser - - -class ProxyPoolPlugin(HttpProxyBasePlugin): - """Proxy incoming client proxy requests through a set of upstream proxies.""" - - # Run two separate instances of proxy.py - # on port 9000 and 9001 BUT WITHOUT ProxyPool plugin - # to avoid infinite loops. - UPSTREAM_PROXY_POOL = [ - ('localhost', 9000), - ('localhost', 9001), - ] - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.conn: Optional[socket.socket] = None - - def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: - """Avoid upstream connection of server in the request. - Initialize, connection to upstream proxy. - """ - # Implement your own logic here e.g. round-robin, least connection etc. - self.conn = new_socket_connection( - random.choice(self.UPSTREAM_PROXY_POOL)) - return None - - def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: - request.path = self.rebuild_original_path(request) - self.tunnel(request) - # Returning None indicates core to gracefully - # flush client buffer and teardown the connection - return None - - def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: - """Will never be called since we didn't establish an upstream connection.""" - return chunk - - def on_upstream_connection_close(self) -> None: - """Will never be called since we didn't establish an upstream connection.""" - pass - - def tunnel(self, request: HttpParser) -> None: - """Send to upstream proxy, receive from upstream proxy, queue back to client.""" - assert self.conn - self.conn.send(request.build()) - response = self.conn.recv(DEFAULT_BUFFER_SIZE) - self.client.queue(memoryview(response)) - - @staticmethod - def rebuild_original_path(request: HttpParser) -> bytes: - """Re-builds original upstream server URL. - - proxy server core by default strips upstream host:port - from incoming client proxy request. - """ - assert request.url and request.host and request.port and request.path - return ( - request.url.scheme + - COLON + SLASH + SLASH + - request.host + - COLON + - str(request.port).encode() + - request.path - ) diff --git a/proxy/plugin/redirect_to_custom_server.py b/proxy/plugin/redirect_to_custom_server.py deleted file mode 100644 index 1e1fe8830..000000000 --- a/proxy/plugin/redirect_to_custom_server.py +++ /dev/null @@ -1,45 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. - :license: BSD, see LICENSE for more details. -""" -from urllib import parse as urlparse -from typing import Optional - -from ..http.proxy import HttpProxyBasePlugin -from ..http.parser import HttpParser -from ..http.methods import httpMethods - - -class RedirectToCustomServerPlugin(HttpProxyBasePlugin): - """Modifies client request to redirect all incoming requests to a fixed server address.""" - - UPSTREAM_SERVER = b'http://localhost:8545/' - - def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: - # Redirect all non-https requests to inbuilt WebServer. 
- if request.method != httpMethods.CONNECT: - request.set_url(self.UPSTREAM_SERVER) - # Update Host header too, otherwise upstream can reject our request - if request.has_header(b'Host'): - request.del_header(b'Host') - request.add_header( - b'Host', urlparse.urlsplit( - self.UPSTREAM_SERVER).netloc) - return request - - def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: - return request - - def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: - return chunk - - def on_upstream_connection_close(self) -> None: - pass diff --git a/proxy/plugin/reverse_proxy.py b/proxy/plugin/reverse_proxy.py deleted file mode 100644 index 57be6b87a..000000000 --- a/proxy/plugin/reverse_proxy.py +++ /dev/null @@ -1,76 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. - :license: BSD, see LICENSE for more details. -""" -import random -from typing import List, Tuple -from urllib import parse as urlparse - -from ..common.constants import DEFAULT_BUFFER_SIZE, DEFAULT_HTTP_PORT -from ..common.utils import socket_connection, text_ -from ..http.parser import HttpParser -from ..http.websocket import WebsocketFrame -from ..http.server import HttpWebServerBasePlugin, httpProtocolTypes - - -class ReverseProxyPlugin(HttpWebServerBasePlugin): - """Extend in-built Web Server to add Reverse Proxy capabilities. - - This example plugin is equivalent to following Nginx configuration: - - location /get { - proxy_pass http://httpbin.org/get - } - - Example: - - $ curl http://localhost:9000/get - { - "args": {}, - "headers": { - "Accept": "*/*", - "Host": "localhost", - "User-Agent": "curl/7.64.1" - }, - "origin": "1.2.3.4, 5.6.7.8", - "url": "https://localhost/get" - } - """ - - REVERSE_PROXY_LOCATION: str = r'/api$' - REVERSE_PROXY_PASS = [ - b'http://localhost:8545/' - ] - - def routes(self) -> List[Tuple[int, str]]: - return [ - (httpProtocolTypes.HTTP, ReverseProxyPlugin.REVERSE_PROXY_LOCATION), - (httpProtocolTypes.HTTPS, ReverseProxyPlugin.REVERSE_PROXY_LOCATION) - ] - - def handle_request(self, request: HttpParser) -> None: - upstream = random.choice(ReverseProxyPlugin.REVERSE_PROXY_PASS) - url = urlparse.urlsplit(upstream) - print(request.body) - assert url.hostname - with socket_connection((text_(url.hostname), url.port if url.port else DEFAULT_HTTP_PORT)) as conn: - conn.send(request.build()) - raw = memoryview(conn.recv(DEFAULT_BUFFER_SIZE)) - response = HttpParser.response(raw) - print(response.body) - self.client.queue(raw) - - def on_websocket_open(self) -> None: - pass - - def on_websocket_message(self, frame: WebsocketFrame) -> None: - pass - - def on_websocket_close(self) -> None: - pass diff --git a/proxy/plugin/shortlink.py b/proxy/plugin/shortlink.py deleted file mode 100644 index 309fc1fbc..000000000 --- a/proxy/plugin/shortlink.py +++ /dev/null @@ -1,84 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. - :license: BSD, see LICENSE for more details. 
-""" -from typing import Optional - -from ..common.constants import DOT, SLASH -from ..common.utils import build_http_response -from ..http.parser import HttpParser -from ..http.codes import httpStatusCodes -from ..http.proxy import HttpProxyBasePlugin - - -class ShortLinkPlugin(HttpProxyBasePlugin): - """Add support for short links in your favorite browsers / applications. - - Enable ShortLinkPlugin and speed up your daily browsing experience. - - Example: - * f/ for facebook.com - * g/ for google.com - * t/ for twitter.com - * y/ for youtube.com - * proxy/ for py internal web servers. - Customize map below for your taste and need. - - Paths are also preserved. E.g. t/imoracle will - resolve to http://twitter.com/imoracle. - """ - - SHORT_LINKS = { - b'a': b'amazon.com', - b'i': b'instagram.com', - b'l': b'linkedin.com', - b'f': b'facebook.com', - b'g': b'google.com', - b't': b'twitter.com', - b'w': b'web.whatsapp.com', - b'y': b'youtube.com', - b'proxy': b'localhost:8899', - } - - def before_upstream_connection( - self, request: HttpParser) -> Optional[HttpParser]: - if request.host and request.host != b'localhost' and DOT not in request.host: - # Avoid connecting to upstream - return None - return request - - def handle_client_request( - self, request: HttpParser) -> Optional[HttpParser]: - if request.host and request.host != b'localhost' and DOT not in request.host: - if request.host in self.SHORT_LINKS: - path = SLASH if not request.path else request.path - self.client.queue(memoryview(build_http_response( - httpStatusCodes.SEE_OTHER, reason=b'See Other', - headers={ - b'Location': b'http://' + self.SHORT_LINKS[request.host] + path, - b'Content-Length': b'0', - b'Connection': b'close', - } - ))) - else: - self.client.queue(memoryview(build_http_response( - httpStatusCodes.NOT_FOUND, reason=b'NOT FOUND', - headers={ - b'Content-Length': b'0', - b'Connection': b'close', - } - ))) - return None - return request - - def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: - return chunk - - def on_upstream_connection_close(self) -> None: - pass diff --git a/proxy/plugin/web_server_route.py b/proxy/plugin/web_server_route.py deleted file mode 100644 index c8b4731a4..000000000 --- a/proxy/plugin/web_server_route.py +++ /dev/null @@ -1,48 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. - :license: BSD, see LICENSE for more details. 
-""" -import logging -from typing import List, Tuple - -from ..common.utils import build_http_response -from ..http.parser import HttpParser -from ..http.codes import httpStatusCodes -from ..http.websocket import WebsocketFrame -from ..http.server import HttpWebServerBasePlugin, httpProtocolTypes - -logger = logging.getLogger(__name__) - - -class WebServerPlugin(HttpWebServerBasePlugin): - """Demonstrates inbuilt web server routing using plugin.""" - - def routes(self) -> List[Tuple[int, str]]: - return [ - (httpProtocolTypes.HTTP, r'/http-route-example$'), - (httpProtocolTypes.HTTPS, r'/https-route-example$'), - (httpProtocolTypes.WEBSOCKET, r'/ws-route-example$'), - ] - - def handle_request(self, request: HttpParser) -> None: - if request.path == b'/http-route-example': - self.client.queue(memoryview(build_http_response( - httpStatusCodes.OK, body=b'HTTP route response'))) - elif request.path == b'/https-route-example': - self.client.queue(memoryview(build_http_response( - httpStatusCodes.OK, body=b'HTTPS route response'))) - - def on_websocket_open(self) -> None: - logger.info('Websocket open') - - def on_websocket_message(self, frame: WebsocketFrame) -> None: - logger.info(frame.data) - - def on_websocket_close(self) -> None: - logger.info('Websocket close') diff --git a/proxy/proxy.py b/proxy/proxy.py index 55bd3f968..30b9ee89e 100644 --- a/proxy/proxy.py +++ b/proxy/proxy.py @@ -8,83 +8,323 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ -import contextlib import os import sys import time +import pprint +import signal +import logging +import threading +from typing import TYPE_CHECKING, Any, List, Optional, cast -from types import TracebackType -from typing import List, Optional, Generator, Any, Type - +from .core.ssh import SshTunnelListener, SshHttpProtocolHandler +from .core.work import ThreadlessPool +from .core.event import EventManager +from .common.flag import FlagParser, flags from .common.utils import bytes_ -from .common.flags import Flags +from .core.work.fd import RemoteFdExecutor from .core.acceptor import AcceptorPool -from .http.handler import HttpProtocolHandler +from .core.listener import ListenerPool +from .common.constants import ( + IS_WINDOWS, DEFAULT_PLUGINS, DEFAULT_VERSION, DEFAULT_LOG_FILE, + DEFAULT_PID_FILE, DEFAULT_LOG_LEVEL, DEFAULT_BASIC_AUTH, + DEFAULT_LOG_FORMAT, DEFAULT_WORK_KLASS, DEFAULT_OPEN_FILE_LIMIT, + DEFAULT_ENABLE_SSH_TUNNEL, +) + + +if TYPE_CHECKING: # pragma: no cover + from .core.listener import TcpSocketListener + + +logger = logging.getLogger(__name__) + + +flags.add_argument( + '--version', + '-v', + action='store_true', + default=DEFAULT_VERSION, + help='Prints proxy.py version.', +) + +# TODO: Add --verbose option which also +# starts to log traffic flowing between +# clients and upstream servers. +flags.add_argument( + '--log-level', + type=str, + default=DEFAULT_LOG_LEVEL, + help='Valid options: DEBUG, INFO (default), WARNING, ERROR, CRITICAL. ' + 'Both upper and lowercase values are allowed. ' + 'You may also simply use the leading character e.g. --log-level d', +) + +flags.add_argument( + '--log-file', + type=str, + default=DEFAULT_LOG_FILE, + help='Default: sys.stdout. 
Log file destination.', +) -from multiprocessing import Process +flags.add_argument( + '--log-format', + type=str, + default=DEFAULT_LOG_FORMAT, + help='Log format for Python logger.', +) + +flags.add_argument( + '--open-file-limit', + type=int, + default=DEFAULT_OPEN_FILE_LIMIT, + help='Default: 1024. Maximum number of files (TCP connections) ' + 'that proxy.py can open concurrently.', +) + +flags.add_argument( + '--plugins', + action='append', + nargs='+', + default=DEFAULT_PLUGINS, + help='Comma separated plugins. ' + + 'You may use --plugins flag multiple times.', +) + +flags.add_argument( + '--enable-ssh-tunnel', + action='store_true', + default=DEFAULT_ENABLE_SSH_TUNNEL, + help='Default: False. Enable SSH tunnel.', +) + +flags.add_argument( + '--work-klass', + type=str, + default=DEFAULT_WORK_KLASS, + help='Default: ' + DEFAULT_WORK_KLASS + + '. Work klass to use for work execution.', +) + +flags.add_argument( + '--pid-file', + type=str, + default=DEFAULT_PID_FILE, + help='Default: None. Save "parent" process ID to a file.', +) + +flags.add_argument( + '--openssl', + type=str, + default='openssl', + help='Default: openssl. Path to openssl binary. ' + + 'By default, assumption is that openssl is in your PATH.', +) class Proxy: + """Proxy is a context manager to control proxy.py library core. - def __init__(self, input_args: Optional[List[str]], **opts: Any) -> None: - self.flags = Flags.initialize(input_args, **opts) - self.acceptors: Optional[AcceptorPool] = None - self.indexer: Optional[Process] = None + By default, :class:`~proxy.core.pool.AcceptorPool` is started with + :class:`~proxy.http.handler.HttpProtocolHandler` work class. + By definition, it expects HTTP traffic to flow between clients and server. - def write_pid_file(self) -> None: - if self.flags.pid_file is not None: - with open(self.flags.pid_file, 'wb') as pid_file: - pid_file.write(bytes_(os.getpid())) + In ``--threadless`` mode and without ``--local-executor``, + a :class:`~proxy.core.executors.ThreadlessPool` is also started. + Executor pool receives newly accepted work by :class:`~proxy.core.acceptor.Acceptor` + and creates an instance of work class for processing the received work. - def delete_pid_file(self) -> None: - if self.flags.pid_file and os.path.exists(self.flags.pid_file): - os.remove(self.flags.pid_file) + In ``--threadless`` mode and with ``--local-executor 0``, + acceptors will start a companion thread to handle accepted + client connections. + + Optionally, Proxy class also initializes the EventManager. + A multi-process safe pubsub system which can be used to build various + patterns for message sharing and/or signaling. 
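In practice the context-manager form keeps startup and teardown symmetric. A minimal embedding sketch; the flag values are illustrative, and `from proxy import Proxy` assumes the package-level export is unchanged:

from proxy import Proxy

# Entering the block runs setup(): listeners, then executors, then acceptors.
# Leaving it runs shutdown(), which tears the same pieces down in reverse.
with Proxy(['--num-workers', '2', '--port', '8899']) as p:
    assert p.acceptors is not None  # acceptor pool is live inside the block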
+ """ + + def __init__(self, input_args: Optional[List[str]] = None, **opts: Any) -> None: + self.flags = FlagParser.initialize(input_args, **opts) + self.listeners: Optional[ListenerPool] = None + self.executors: Optional[ThreadlessPool] = None + self.acceptors: Optional[AcceptorPool] = None + self.event_manager: Optional[EventManager] = None + self.ssh_http_protocol_handler: Optional[SshHttpProtocolHandler] = None + self.ssh_tunnel_listener: Optional[SshTunnelListener] = None def __enter__(self) -> 'Proxy': + self.setup() + return self + + def __exit__(self, *args: Any) -> None: + self.shutdown() + + def setup(self) -> None: + # TODO: Introduce cron feature + # https://github.com/abhinavsingh/proxy.py/discussions/808 + # + # TODO: Introduce ability to change flags dynamically + # https://github.com/abhinavsingh/proxy.py/discussions/1020 + # + # TODO: Python shell within running proxy.py environment + # https://github.com/abhinavsingh/proxy.py/discussions/1021 + # + # TODO: Near realtime resource / stats monitoring + # https://github.com/abhinavsingh/proxy.py/discussions/1023 + # + self._write_pid_file() + # We setup listeners first because of flags.port override + # in case of ephemeral port being used + self.listeners = ListenerPool(flags=self.flags) + self.listeners.setup() + # Override flags.port to match the actual port + # we are listening upon. This is necessary to preserve + # the server port when `--port=0` is used. + if not self.flags.unix_socket_path: + self.flags.port = cast( + 'TcpSocketListener', + self.listeners.pool[0], + )._port + # --ports flag can also use 0 as value for ephemeral port selection. + # Here, we override flags.ports to reflect actual listening ports. + ports = [] + offset = 1 if self.flags.unix_socket_path or self.flags.port else 0 + for index in range(offset, offset + len(self.flags.ports)): + ports.append( + cast( + 'TcpSocketListener', + self.listeners.pool[index], + )._port, + ) + self.flags.ports = ports + # Write ports to port file + self._write_port_file() + # Setup EventManager + if self.flags.enable_events: + logger.info('Core Event enabled') + self.event_manager = EventManager() + self.event_manager.setup() + event_queue = self.event_manager.queue \ + if self.event_manager is not None \ + else None + # Setup remote executors only if + # --local-executor mode isn't enabled. 
+ if self.remote_executors_enabled: + self.executors = ThreadlessPool( + flags=self.flags, + event_queue=event_queue, + executor_klass=RemoteFdExecutor, + ) + self.executors.setup() + # Setup acceptors self.acceptors = AcceptorPool( flags=self.flags, - work_klass=HttpProtocolHandler + listeners=self.listeners, + executor_queues=self.executors.work_queues if self.executors else [], + executor_pids=self.executors.work_pids if self.executors else [], + executor_locks=self.executors.work_locks if self.executors else [], + event_queue=event_queue, ) self.acceptors.setup() - self.write_pid_file() - return self + # Start SSH tunnel acceptor if enabled + if self.flags.enable_ssh_tunnel: + self.ssh_http_protocol_handler = SshHttpProtocolHandler( + flags=self.flags, + ) + self.ssh_tunnel_listener = SshTunnelListener( + flags=self.flags, + on_connection_callback=self.ssh_http_protocol_handler.on_connection, + ) + self.ssh_tunnel_listener.setup() + self.ssh_tunnel_listener.start_port_forward( + ('', self.flags.tunnel_remote_port), + ) + # TODO: May be close listener fd as we don't need it now + if threading.current_thread() == threading.main_thread(): + self._register_signals() - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: + def shutdown(self) -> None: + if self.flags.enable_ssh_tunnel: + assert self.ssh_tunnel_listener is not None + self.ssh_tunnel_listener.shutdown() assert self.acceptors self.acceptors.shutdown() - self.indexer.terminate() - self.delete_pid_file() + if self.remote_executors_enabled: + assert self.executors + self.executors.shutdown() + if self.flags.enable_events: + assert self.event_manager is not None + self.event_manager.shutdown() + if self.listeners: + self.listeners.shutdown() + self._delete_port_file() + self._delete_pid_file() + + @property + def remote_executors_enabled(self) -> bool: + return self.flags.threadless and \ + not self.flags.local_executor + + def _write_pid_file(self) -> None: + if self.flags.pid_file: + with open(self.flags.pid_file, 'wb') as pid_file: + pid_file.write(bytes_(os.getpid())) + + def _delete_pid_file(self) -> None: + if self.flags.pid_file \ + and os.path.exists(self.flags.pid_file): + os.remove(self.flags.pid_file) + + def _write_port_file(self) -> None: + if self.flags.port_file: + with open(self.flags.port_file, 'wb') as port_file: + if not self.flags.unix_socket_path: + port_file.write(bytes_(self.flags.port)) + port_file.write(b'\n') + for port in self.flags.ports: + port_file.write(bytes_(port)) + port_file.write(b'\n') + + def _delete_port_file(self) -> None: + if self.flags.port_file \ + and os.path.exists(self.flags.port_file): + os.remove(self.flags.port_file) + + def _register_signals(self) -> None: + # TODO: Define SIGUSR1, SIGUSR2 + signal.signal(signal.SIGINT, self._handle_exit_signal) + signal.signal(signal.SIGTERM, self._handle_exit_signal) + if not IS_WINDOWS: + if hasattr(signal, 'SIGINFO'): + signal.signal( # pragma: no cover + signal.SIGINFO, # pylint: disable=E1101 + self._handle_siginfo, + ) + signal.signal(signal.SIGHUP, self._handle_exit_signal) + # TODO: SIGQUIT is ideally meant to terminate with core dumps + signal.signal(signal.SIGQUIT, self._handle_exit_signal) + + @staticmethod + def _handle_exit_signal(signum: int, _frame: Any) -> None: + logger.info('Received signal %d' % signum) + sys.exit(0) + def _handle_siginfo(self, _signum: int, _frame: Any) -> None: + pprint.pprint(self.flags.__dict__) # pragma: no cover 
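The ephemeral-port handling in setup() pairs with the port file written above: start on port 0, then recover the real binding either from flags.port or from the file's first line. A sketch, assuming the --port-file flag is wired to flags.port_file as used in _write_port_file; the path is illustrative:

from proxy import Proxy

# Bind an ephemeral port, then read back the actual port chosen by the kernel.
with Proxy(['--port', '0', '--port-file', '/tmp/proxy.port']) as p:
    print('listening on', p.flags.port)  # overridden by setup() after bind
    with open('/tmp/proxy.port', 'rb') as port_file:
        assert int(port_file.readline()) == p.flags.port  # first line: main port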
-@contextlib.contextmanager -def start( - input_args: Optional[List[str]] = None, - **opts: Any) -> Generator[Proxy, None, None]: - """Deprecated. Kept for backward compatibility. - New users must directly use proxy.Proxy context manager class.""" - try: - with Proxy(input_args, **opts) as p: - yield p - except KeyboardInterrupt: - pass +def sleep_loop() -> None: + while True: + try: + time.sleep(1) + except KeyboardInterrupt: + break -def main( - input_args: Optional[List[str]] = None, - **opts: Any) -> None: - try: - with Proxy(input_args=input_args, **opts): - # TODO: Introduce cron feature instead of mindless sleep - while True: - time.sleep(1) - except KeyboardInterrupt: - pass +def main(**opts: Any) -> None: + with Proxy(sys.argv[1:], **opts): + sleep_loop() def entry_point() -> None: - main(input_args=sys.argv[1:]) + main() diff --git a/tests/plugin/utils.py b/tests/plugin/utils.py index 093436796..ba8b9a9e4 100644 --- a/tests/plugin/utils.py +++ b/tests/plugin/utils.py @@ -12,7 +12,7 @@ from proxy.http.proxy import HttpProxyBasePlugin from proxy.plugin import ModifyPostDataPlugin, ProposedRestApiPlugin, RedirectToCustomServerPlugin, \ - FilterByUpstreamHostPlugin, CacheResponsesPlugin, ManInTheMiddlePlugin + CacheResponsesPlugin, ManInTheMiddlePlugin def get_plugin_by_test_name(test_name: str) -> Type[HttpProxyBasePlugin]: @@ -23,8 +23,6 @@ def get_plugin_by_test_name(test_name: str) -> Type[HttpProxyBasePlugin]: plugin = ProposedRestApiPlugin elif test_name == 'test_redirect_to_custom_server_plugin': plugin = RedirectToCustomServerPlugin - elif test_name == 'test_filter_by_upstream_host_plugin': - plugin = FilterByUpstreamHostPlugin elif test_name == 'test_cache_responses_plugin': plugin = CacheResponsesPlugin elif test_name == 'test_man_in_the_middle_plugin': diff --git a/tests/test_main.py b/tests/test_main.py deleted file mode 100644 index c927a2a12..000000000 --- a/tests/test_main.py +++ /dev/null @@ -1,206 +0,0 @@ -# -*- coding: utf-8 -*- -""" - proxy.py - ~~~~~~~~ - ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on - Network monitoring, controls & Application development, testing, debugging. - - :copyright: (c) 2013-present by Abhinav Singh and contributors. - :license: BSD, see LICENSE for more details. 
-""" -import unittest -import logging -import tempfile -import os - -from unittest import mock -from typing import List - -from proxy.proxy import main -from proxy.common.flags import Flags -from proxy.common.utils import bytes_ -from proxy.http.handler import HttpProtocolHandler - -from proxy.common.constants import DEFAULT_LOG_LEVEL, DEFAULT_LOG_FILE, DEFAULT_LOG_FORMAT, DEFAULT_BASIC_AUTH -from proxy.common.constants import DEFAULT_TIMEOUT, DEFAULT_DEVTOOLS_WS_PATH, DEFAULT_DISABLE_HTTP_PROXY -from proxy.common.constants import DEFAULT_ENABLE_STATIC_SERVER, DEFAULT_ENABLE_EVENTS, DEFAULT_ENABLE_DEVTOOLS -from proxy.common.constants import DEFAULT_ENABLE_WEB_SERVER, DEFAULT_THREADLESS, DEFAULT_CERT_FILE, DEFAULT_KEY_FILE -from proxy.common.constants import DEFAULT_CA_CERT_FILE, DEFAULT_CA_KEY_FILE, DEFAULT_CA_SIGNING_KEY_FILE -from proxy.common.constants import DEFAULT_PAC_FILE, DEFAULT_PLUGINS, DEFAULT_PID_FILE, DEFAULT_PORT -from proxy.common.constants import DEFAULT_NUM_WORKERS, DEFAULT_OPEN_FILE_LIMIT, DEFAULT_IPV6_HOSTNAME -from proxy.common.constants import DEFAULT_SERVER_RECVBUF_SIZE, DEFAULT_CLIENT_RECVBUF_SIZE -from proxy.common.version import __version__ - - -def get_temp_file(name: str) -> str: - return os.path.join(tempfile.gettempdir(), name) - - -class TestMain(unittest.TestCase): - - @staticmethod - def mock_default_args(mock_args: mock.Mock) -> None: - mock_args.version = False - mock_args.cert_file = DEFAULT_CERT_FILE - mock_args.key_file = DEFAULT_KEY_FILE - mock_args.ca_key_file = DEFAULT_CA_KEY_FILE - mock_args.ca_cert_file = DEFAULT_CA_CERT_FILE - mock_args.ca_signing_key_file = DEFAULT_CA_SIGNING_KEY_FILE - mock_args.pid_file = DEFAULT_PID_FILE - mock_args.log_file = DEFAULT_LOG_FILE - mock_args.log_level = DEFAULT_LOG_LEVEL - mock_args.log_format = DEFAULT_LOG_FORMAT - mock_args.basic_auth = DEFAULT_BASIC_AUTH - mock_args.hostname = DEFAULT_IPV6_HOSTNAME - mock_args.port = DEFAULT_PORT - mock_args.num_workers = DEFAULT_NUM_WORKERS - mock_args.disable_http_proxy = DEFAULT_DISABLE_HTTP_PROXY - mock_args.enable_web_server = DEFAULT_ENABLE_WEB_SERVER - mock_args.pac_file = DEFAULT_PAC_FILE - mock_args.plugins = DEFAULT_PLUGINS - mock_args.server_recvbuf_size = DEFAULT_SERVER_RECVBUF_SIZE - mock_args.client_recvbuf_size = DEFAULT_CLIENT_RECVBUF_SIZE - mock_args.open_file_limit = DEFAULT_OPEN_FILE_LIMIT - mock_args.enable_static_server = DEFAULT_ENABLE_STATIC_SERVER - mock_args.enable_devtools = DEFAULT_ENABLE_DEVTOOLS - mock_args.devtools_event_queue = None - mock_args.devtools_ws_path = DEFAULT_DEVTOOLS_WS_PATH - mock_args.timeout = DEFAULT_TIMEOUT - mock_args.threadless = DEFAULT_THREADLESS - mock_args.enable_events = DEFAULT_ENABLE_EVENTS - - @mock.patch('time.sleep') - @mock.patch('proxy.proxy.Flags') - @mock.patch('proxy.proxy.AcceptorPool') - @mock.patch('logging.basicConfig') - def test_init_with_no_arguments( - self, - mock_logging_config: mock.Mock, - mock_acceptor_pool: mock.Mock, - mock_flags: mock.Mock, - mock_sleep: mock.Mock) -> None: - mock_sleep.side_effect = KeyboardInterrupt() - - input_args: List[str] = [] - flags = Flags.initialize(input_args=input_args) - mock_flags.initialize = lambda *args, **kwargs: flags - - main() - - mock_logging_config.assert_called_with( - level=logging.INFO, - format=DEFAULT_LOG_FORMAT - ) - mock_acceptor_pool.assert_called_with( - flags=flags, - work_klass=HttpProtocolHandler, - ) - mock_acceptor_pool.return_value.setup.assert_called() - mock_acceptor_pool.return_value.shutdown.assert_called() - mock_sleep.assert_called() 
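These tests disappear without a direct replacement because they pin the removed Flags and work_klass surface. A smoke test against the new API might look like the following sketch; the test body is hypothetical, while FlagParser.initialize is the entry point Proxy itself uses above:

import unittest

from proxy.common.flag import FlagParser


class TestFlagParser(unittest.TestCase):

    def test_port_flag_is_parsed(self) -> None:
        # FlagParser.initialize replaces Flags.initialize from the old API.
        flags = FlagParser.initialize(['--port', '8899'])
        self.assertEqual(flags.port, 8899)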
- - @mock.patch('time.sleep') - @mock.patch('os.remove') - @mock.patch('os.path.exists') - @mock.patch('builtins.open') - @mock.patch('proxy.proxy.Flags.init_parser') - @mock.patch('proxy.proxy.AcceptorPool') - def test_pid_file_is_written_and_removed( - self, - mock_acceptor_pool: mock.Mock, - mock_init_parser: mock.Mock, - mock_open: mock.Mock, - mock_exists: mock.Mock, - mock_remove: mock.Mock, - mock_sleep: mock.Mock) -> None: - pid_file = get_temp_file('pid') - mock_sleep.side_effect = KeyboardInterrupt() - mock_args = mock_init_parser.return_value.parse_args.return_value - self.mock_default_args(mock_args) - mock_args.pid_file = pid_file - main(['--pid-file', pid_file]) - mock_init_parser.assert_called() - mock_acceptor_pool.assert_called() - mock_acceptor_pool.return_value.setup.assert_called() - mock_open.assert_called_with(pid_file, 'wb') - mock_open.return_value.__enter__.return_value.write.assert_called_with( - bytes_(os.getpid())) - mock_exists.assert_called_with(pid_file) - mock_remove.assert_called_with(pid_file) - - @mock.patch('time.sleep') - @mock.patch('proxy.proxy.Flags') - @mock.patch('proxy.proxy.AcceptorPool') - def test_basic_auth( - self, - mock_acceptor_pool: mock.Mock, - mock_flags: mock.Mock, - mock_sleep: mock.Mock) -> None: - mock_sleep.side_effect = KeyboardInterrupt() - - input_args = ['--basic-auth', 'user:pass'] - flags = Flags.initialize(input_args=input_args) - mock_flags.initialize = lambda *args, **kwargs: flags - - main(input_args=input_args) - mock_acceptor_pool.assert_called_with( - flags=flags, - work_klass=HttpProtocolHandler) - self.assertEqual( - flags.auth_code, - b'Basic dXNlcjpwYXNz') - - @mock.patch('time.sleep') - @mock.patch('builtins.print') - @mock.patch('proxy.proxy.Flags') - @mock.patch('proxy.proxy.AcceptorPool') - @mock.patch('proxy.proxy.Flags.is_py3') - def test_main_py3_runs( - self, - mock_is_py3: mock.Mock, - mock_acceptor_pool: mock.Mock, - mock_flags: mock.Mock, - mock_print: mock.Mock, - mock_sleep: mock.Mock) -> None: - mock_sleep.side_effect = KeyboardInterrupt() - - input_args = ['--basic-auth', 'user:pass'] - flags = Flags.initialize(input_args=input_args) - mock_flags.initialize = lambda *args, **kwargs: flags - - mock_is_py3.return_value = True - main(num_workers=1) - mock_is_py3.assert_called() - mock_print.assert_not_called() - mock_acceptor_pool.assert_called() - mock_acceptor_pool.return_value.setup.assert_called() - - @mock.patch('builtins.print') - @mock.patch('proxy.proxy.Flags.is_py3') - def test_main_py2_exit( - self, - mock_is_py3: mock.Mock, - mock_print: mock.Mock) -> None: - mock_is_py3.return_value = False - with self.assertRaises(SystemExit) as e: - main(num_workers=1) - mock_print.assert_called_with( - 'DEPRECATION: "develop" branch no longer supports Python 2.7. Kindly upgrade to Python 3+. ' - 'If for some reasons you cannot upgrade, consider using "master" branch or simply ' - '"pip install proxy.py==0.3".' - '\n\n' - 'DEPRECATION: Python 2.7 will reach the end of its life on January 1st, 2020. ' - 'Please upgrade your Python as Python 2.7 won\'t be maintained after that date. ' - 'A future version of pip will drop support for Python 2.7.' - ) - self.assertEqual(e.exception.code, 1) - mock_is_py3.assert_called() - - @mock.patch('builtins.print') - def test_main_version( - self, - mock_print: mock.Mock) -> None: - with self.assertRaises(SystemExit) as e: - main(['--version']) - mock_print.assert_called_with(__version__) - self.assertEqual(e.exception.code, 0)
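A final note on the API change these deleted tests tracked: main() no longer accepts input_args and always parses sys.argv[1:], so callers that previously injected an argument vector (as test_main_version did with ['--version']) should construct Proxy directly. A closing sketch reusing sleep_loop() from the new proxy.py; the argv values are illustrative:

from typing import List

from proxy.proxy import Proxy, sleep_loop


def run(argv: List[str]) -> None:
    # Replaces the removed main(input_args=...) calling convention.
    with Proxy(argv):
        sleep_loop()  # returns on KeyboardInterrupt, after which shutdown() runs


run(['--port', '8899', '--num-workers', '2'])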