diff --git a/src/requests/auth.py b/src/requests/auth.py index 4a7ce6dc14..da086f6d4a 100644 --- a/src/requests/auth.py +++ b/src/requests/auth.py @@ -123,6 +123,45 @@ def init_per_thread_state(self): self._thread_local.pos = None self._thread_local.num_401_calls = None + @staticmethod + def _encode_data(data, codec="latin-1"): + """ + This function encodes input data to bytes using the specified + encoding (default is Latin-1). It returns the encoded data as bytes. + :rtype: tuple[bytes, Optional[str]] + """ + if type(data) is bytes: + return data, None + try: + return str(data).encode(codec), codec + except UnicodeEncodeError: + warnings.warn( + "This data will be encoded with UTF-8 because the provided " + "encoding could not handle some characters.", + category=UnicodeWarning, + ) + if codec != "utf-8": + return HTTPDigestAuth._encode_data(data, "utf-8") + else: + raise UnicodeEncodeError("Cannot encode the provided data...") + + @staticmethod + def _decode_data(data, codec): + """ + This function decodes input data from bytes using the specified + encoding. It returns the decoded data as a string. + :rtype: str + """ + if type(data) is not bytes: + return data + if codec is None: + warnings.warn( + "No encoding provided. The data will be decoded using UTF-8.", + category=UnicodeWarning, + ) + codec = "utf-8" + return data.decode(codec) + def build_digest_header(self, method, url): """ :rtype: str @@ -186,7 +225,11 @@ def sha512_utf8(x): if p_parsed.query: path += f"?{p_parsed.query}" - A1 = f"{self.username}:{realm}:{self.password}" + username, username_codec = self._encode_data(self.username) + realm, realm_codec = self._encode_data(realm) + password, _ = self._encode_data(self.password) + + A1 = b":".join([username, realm, password]) A2 = f"{method}:{path}" HA1 = hash_utf8(A1) @@ -218,9 +261,11 @@ def sha512_utf8(x): self._thread_local.last_nonce = nonce # XXX should the partial digests be encoded too? + decoded_username = self._decode_data(username, username_codec) + decoded_realm = self._decode_data(realm, realm_codec) base = ( - f'username="{self.username}", realm="{realm}", nonce="{nonce}", ' - f'uri="{path}", response="{respdig}"' + f'username="{decoded_username}", realm="{decoded_realm}", ' + f'nonce="{nonce}", uri="{path}", response="{respdig}"' ) if opaque: base += f', opaque="{opaque}"' diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000000..a0328e7f31 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,27 @@ +import pytest + +# from requests.auth import HTTPDigestAuth +from src.requests.auth import HTTPDigestAuth + + +class TestDigestAuth: + def _build_a_digest_auth(self, user, password): + auth = HTTPDigestAuth(user, password) + auth.init_per_thread_state() + auth._thread_local.chal["realm"] = "eggs" + auth._thread_local.chal["nonce"] = "chips" + return auth.build_digest_header("GET", "https://www.example.com/") + + @pytest.mark.parametrize( + "username, password", + ( + ("spam", "ham"), + ("имя", "пароль"), + ), + ) + def test_digestauth_encode_consistency(self, username, password): + auth = username, password + str_auth = self._build_a_digest_auth(*auth) + bauth = username.encode("utf-8"), password.encode("utf-8") + bin_auth = self._build_a_digest_auth(*bauth) + assert str_auth == bin_auth