diff --git a/CHANGES.md b/CHANGES.md
index 7e47361ea..9ec1e4455 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -28,6 +28,8 @@
## Bug Fixes
+* [GH-650](https://github.com/apache/mina-sshd/issues/650) Use the correct key from a user certificate in server-side pubkey auth
+
## New Features
## Potential Compatibility Issues
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/auth/pubkey/UserAuthPublicKey.java b/sshd-core/src/main/java/org/apache/sshd/server/auth/pubkey/UserAuthPublicKey.java
index 37719c646..8e1e57173 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/auth/pubkey/UserAuthPublicKey.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/auth/pubkey/UserAuthPublicKey.java
@@ -124,12 +124,16 @@ public Boolean doAuth(Buffer buffer, boolean init) throws Exception {
log.debug("doAuth({}@{}) verify key type={}, factories={}, fingerprint={}",
username, session, alg, NamedResource.getNames(factories), KeyUtils.getFingerPrint(key));
}
-
+ /*
+ * When users employ cert authentication, need to use the public key in the cert for signing
+ * and cannot use the cert itself directly for signing
+ */
+ PublicKey verifyKey = key instanceof OpenSshCertificate ? ((OpenSshCertificate) key).getCertPubKey() : key;
Signature verifier = ValidateUtils.checkNotNull(
NamedFactory.create(factories, alg),
"No verifier located for algorithm=%s",
alg);
- verifier.initVerifier(session, key);
+ verifier.initVerifier(session, verifyKey);
buffer.wpos(oldLim);
byte[] sig = hasSig ? buffer.getBytes() : null;
diff --git a/sshd-core/src/test/java/org/apache/sshd/common/auth/PublicKeyAuthenticationTest.java b/sshd-core/src/test/java/org/apache/sshd/common/auth/PublicKeyAuthenticationTest.java
index 80ca9eadf..134ab37e6 100644
--- a/sshd-core/src/test/java/org/apache/sshd/common/auth/PublicKeyAuthenticationTest.java
+++ b/sshd-core/src/test/java/org/apache/sshd/common/auth/PublicKeyAuthenticationTest.java
@@ -32,8 +32,11 @@
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
+import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Stream;
+import org.apache.sshd.certificate.OpenSshCertificateBuilder;
import org.apache.sshd.client.SshClient;
import org.apache.sshd.client.auth.keyboard.UserInteraction;
import org.apache.sshd.client.auth.pubkey.PublicKeyAuthenticationReporter;
@@ -45,6 +48,7 @@
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.config.keys.FilePasswordProvider;
import org.apache.sshd.common.config.keys.KeyUtils;
+import org.apache.sshd.common.config.keys.OpenSshCertificate;
import org.apache.sshd.common.keyprovider.KeyIdentityProvider;
import org.apache.sshd.common.keyprovider.KeyPairProvider;
import org.apache.sshd.common.session.SessionContext;
@@ -64,13 +68,9 @@
import org.junit.jupiter.api.MethodOrderer.MethodName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestMethodOrder;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNotNull;
-import static org.junit.jupiter.api.Assertions.assertNull;
-import static org.junit.jupiter.api.Assertions.assertSame;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
/**
* @author Apache MINA SSHD Project
@@ -459,4 +459,77 @@ void rsaAuthenticationOldServer() throws Exception {
}
}
}
+
+ @ParameterizedTest(name = "test certificates issued using the {0} algorithm")
+ @MethodSource("certificateAlgorithms")
+ void testCertificateWithDifferentAlgorithms(String keyAlgorithm, int keySize, String signatureAlgorithm) throws Exception {
+ // 1. Generating a user key pair
+ KeyPair userkey = CommonTestSupportUtils.generateKeyPair(keyAlgorithm, keySize);
+ // 2. Generating CA key pair
+ KeyPair caKeypair = CommonTestSupportUtils.generateKeyPair(keyAlgorithm, keySize);
+
+ // 3. Building openSshCertificate
+ OpenSshCertificate signedCert = OpenSshCertificateBuilder.userCertificate()
+ .serial(System.currentTimeMillis())
+ .publicKey(userkey.getPublic())
+ .id("test-cert-" + keyAlgorithm)
+ .validBefore(System.currentTimeMillis() + TimeUnit.HOURS.toMillis(1))
+ .principals(Collections.singletonList("user01"))
+ .criticalOptions(Collections.emptyList())
+ .extensions(Arrays.asList(
+ new OpenSshCertificate.CertificateOption("permit-X11-forwarding"),
+ new OpenSshCertificate.CertificateOption("permit-agent-forwarding")))
+ .sign(caKeypair, signatureAlgorithm);
+
+ // 4. Configuring the ssh server
+ sshd.setPasswordAuthenticator(RejectAllPasswordAuthenticator.INSTANCE);
+ sshd.setKeyboardInteractiveAuthenticator(KeyboardInteractiveAuthenticator.NONE);
+ CoreTestSupportUtils.setupFullSignaturesSupport(sshd);
+
+ sshd.setUserAuthFactories(Collections.singletonList(
+ new org.apache.sshd.server.auth.pubkey.UserAuthPublicKeyFactory()));
+
+ AtomicInteger authAttempts = new AtomicInteger(0);
+ sshd.setPublickeyAuthenticator((username, key, session) -> {
+ authAttempts.incrementAndGet();
+ if (key instanceof OpenSshCertificate) {
+ OpenSshCertificate cert = (OpenSshCertificate) key;
+ return KeyUtils.compareKeys(cert.getCaPubKey(), caKeypair.getPublic());
+ }
+ return false;
+ });
+
+ // 5. Testing Client Authentication
+ try (SshClient client = setupTestClient()) {
+ CoreTestSupportUtils.setupFullSignaturesSupport(client);
+ client.setUserAuthFactories(Collections.singletonList(
+ new org.apache.sshd.client.auth.pubkey.UserAuthPublicKeyFactory()));
+
+ client.start();
+
+ try (ClientSession session = client.connect("user01", TEST_LOCALHOST, port)
+ .verify(CONNECT_TIMEOUT)
+ .getSession()) {
+
+ KeyPair certKeyPair = new KeyPair(signedCert, userkey.getPrivate());
+ session.addPublicKeyIdentity(certKeyPair);
+
+ AuthFuture auth = session.auth();
+ assertTrue(auth.verify(AUTH_TIMEOUT).isSuccess());
+ assertEquals(2, authAttempts.get(), "There should be two attempts to authenticate using the certificate");
+ } finally {
+ client.stop();
+ }
+ }
+ }
+
+ private static Stream certificateAlgorithms() {
+ return Stream.of(
+ // key size, signature algorithm, algorithm name
+ Arguments.of(KeyUtils.RSA_ALGORITHM, 2048, "rsa-sha2-512"),
+ Arguments.of(KeyUtils.RSA_ALGORITHM, 2048, "rsa-sha2-256"),
+ Arguments.of(KeyUtils.EC_ALGORITHM, 256, "ecdsa-sha2-nistp256"),
+ Arguments.of(KeyUtils.EC_ALGORITHM, 384, "ecdsa-sha2-nistp384"),
+ Arguments.of(KeyUtils.EC_ALGORITHM, 521, "ecdsa-sha2-nistp521"));
+ }
}