diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 1f8865afb..739df808f 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -417,6 +417,12 @@ def prepare_key(self, key): except ValueError: key = load_pem_private_key(key, password=None) + # Explicit check the key to prevent confusing errors from cryptography + if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): + raise InvalidKeyError( + "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms" + ) + return key def sign(self, msg, key): diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index b6a73fc4d..f4ab75bbc 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -494,6 +494,18 @@ def test_ec_verify_should_return_false_if_signature_wrong_length(self): result = algo.verify(message, pub_key, sig) assert not result + @crypto_required + def test_ec_should_throw_exception_on_wrong_key(self): + algo = ECAlgorithm(ECAlgorithm.SHA256) + + with pytest.raises(InvalidKeyError): + with open(key_path("testkey_rsa.priv")) as keyfile: + algo.prepare_key(keyfile.read()) + + with pytest.raises(InvalidKeyError): + with open(key_path("testkey2_rsa.pub.pem")) as pem_key: + algo.prepare_key(pem_key.read()) + @crypto_required def test_rsa_pss_sign_then_verify_should_return_true(self): algo = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256)