diff --git a/README.md b/README.md index 9034319c..11371f5e 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Quick links: The library supports the following Java environments: - Java 8 (or higher) -Current version - 1.23.0 +Current version - 1.23.1 You can find the changes for each version in the [change log](https://github.com/AzureAD/microsoft-authentication-library-for-java/blob/main/msal4j-sdk/changelog.txt). @@ -28,13 +28,13 @@ Find [the latest package in the Maven repository](https://mvnrepository.com/arti com.microsoft.azure msal4j - 1.23.0 + 1.23.1 ``` ### Gradle ```gradle -implementation group: 'com.microsoft.azure', name: 'com.microsoft.aad.msal4j', version: '1.23.0' +implementation group: 'com.microsoft.azure', name: 'com.microsoft.aad.msal4j', version: '1.23.1' ``` ## Usage diff --git a/changelog.txt b/changelog.txt index 69dd36b0..7b0f733b 100644 --- a/changelog.txt +++ b/changelog.txt @@ -1,3 +1,10 @@ +Version 1.23.1 +============= +- Fix regression and other issues related to client credentials (#986) + - Fix for regression after latest release where certificate-based assertions were never refreshed (#984) + - Fix/improve behavior to properly handle callback-based assertions set at the application level (#879, #977) + - Generally improve internal assertion behavior by making assertions entirely per-request + Version 1.23.0 ============= - Reduced dependency footprint by removing third-party libraries (#909): diff --git a/msal4j-sdk/README.md b/msal4j-sdk/README.md index 0edddadb..20b006c4 100644 --- a/msal4j-sdk/README.md +++ b/msal4j-sdk/README.md @@ -16,7 +16,7 @@ Quick links: The library supports the following Java environments: - Java 8 (or higher) -Current version - 1.23.0 +Current version - 1.23.1 You can find the changes for each version in the [change log](https://github.com/AzureAD/microsoft-authentication-library-for-java/blob/master/changelog.txt). @@ -28,13 +28,13 @@ Find [the latest package in the Maven repository](https://mvnrepository.com/arti com.microsoft.azure msal4j - 1.23.0 + 1.23.1 ``` ### Gradle ```gradle -compile group: 'com.microsoft.azure', name: 'msal4j', version: '1.23.0' +compile group: 'com.microsoft.azure', name: 'msal4j', version: '1.23.1' ``` ## Usage diff --git a/msal4j-sdk/bnd.bnd b/msal4j-sdk/bnd.bnd index cee9edeb..bd075775 100644 --- a/msal4j-sdk/bnd.bnd +++ b/msal4j-sdk/bnd.bnd @@ -1,2 +1,2 @@ -Export-Package: com.microsoft.aad.msal4j;version="1.23.0" +Export-Package: com.microsoft.aad.msal4j;version="1.23.1" Automatic-Module-Name: com.microsoft.aad.msal4j diff --git a/msal4j-sdk/pom.xml b/msal4j-sdk/pom.xml index e1d20cad..15e4d751 100644 --- a/msal4j-sdk/pom.xml +++ b/msal4j-sdk/pom.xml @@ -3,7 +3,7 @@ 4.0.0 com.microsoft.azure msal4j - 1.23.0 + 1.23.1 jar msal4j diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenByAuthorizationGrantSupplier.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenByAuthorizationGrantSupplier.java index 7e43058b..2c8ecd20 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenByAuthorizationGrantSupplier.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenByAuthorizationGrantSupplier.java @@ -52,9 +52,7 @@ AuthenticationResult execute() throws Exception { requestAuthority = clientApplication.authenticationAuthority; } - if (requestAuthority.authorityType == AuthorityType.AAD) { - requestAuthority = getAuthorityWithPrefNetworkHost(requestAuthority.authority()); - } + requestAuthority = getAuthorityWithPrefNetworkHost(requestAuthority.authority()); try { return clientApplication.acquireTokenCommon(msalRequest, requestAuthority); diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenSilentSupplier.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenSilentSupplier.java index 1b073f17..9fdfc14a 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenSilentSupplier.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenSilentSupplier.java @@ -24,11 +24,7 @@ class AcquireTokenSilentSupplier extends AuthenticationResultSupplier { @Override AuthenticationResult execute() throws Exception { boolean shouldRefresh; - Authority requestAuthority = silentRequest.requestAuthority(); - if (requestAuthority.authorityType != AuthorityType.B2C) { - requestAuthority = - getAuthorityWithPrefNetworkHost(silentRequest.requestAuthority().authority()); - } + Authority requestAuthority = getAuthorityWithPrefNetworkHost(silentRequest.requestAuthority().authority()); AuthenticationResult res; if (silentRequest.parameters().account() == null) { diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/Authority.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/Authority.java index 27534329..53058e30 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/Authority.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/Authority.java @@ -131,6 +131,21 @@ static void validateAuthority(URL authorityUrl) { } } + /** + * Creates a new Authority instance with a different tenant. + * This is useful when overriding the tenant at request level. + * + * @param originalAuthority The original authority to base the new one on + * @param newTenant The new tenant to use in the authority URL + * @return A new Authority instance with the specified tenant + */ + static Authority replaceTenant(Authority originalAuthority, String newTenant) throws MalformedURLException { + String authorityString = originalAuthority.canonicalAuthorityUrl().toString(); + authorityString = authorityString.replace(originalAuthority.tenant, newTenant); + + return createAuthority(new URL(authorityString)); + } + static String getTenant(URL authorityUrl, AuthorityType authorityType) { String[] segments = authorityUrl.getPath().substring(1).split("/"); if (authorityType == AuthorityType.B2C) { diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientAssertion.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientAssertion.java index 39f24ad4..1126b461 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientAssertion.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientAssertion.java @@ -4,21 +4,71 @@ package com.microsoft.aad.msal4j; import java.util.Objects; +import java.util.concurrent.Callable; final class ClientAssertion implements IClientAssertion { static final String ASSERTION_TYPE_JWT_BEARER = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; private final String assertion; + private final Callable assertionProvider; + /** + * Constructor that accepts a static assertion string + * + * @param assertion The JWT assertion string to use + * @throws NullPointerException if assertion is null or empty + */ ClientAssertion(final String assertion) { if (StringHelper.isBlank(assertion)) { throw new NullPointerException("assertion"); } this.assertion = assertion; + this.assertionProvider = null; } + /** + * Constructor that accepts a callable that provides the assertion string + * + * @param assertionProvider A callable that returns a JWT assertion string + * @throws NullPointerException if assertionProvider is null + */ + ClientAssertion(final Callable assertionProvider) { + if (assertionProvider == null) { + throw new NullPointerException("assertionProvider"); + } + + this.assertion = null; + this.assertionProvider = assertionProvider; + } + + /** + * Gets the JWT assertion for client authentication. + * If this ClientAssertion was created with a Callable, the callable will be + * invoked each time this method is called to generate a fresh assertion. + * + * @return A JWT assertion string + * @throws MsalClientException if the assertion provider returns null/empty or throws an exception + */ public String assertion() { + if (assertionProvider != null) { + try { + String generatedAssertion = assertionProvider.call(); + + if (StringHelper.isBlank(generatedAssertion)) { + throw new MsalClientException( + "Assertion provider returned null or empty assertion", + AuthenticationErrorCode.INVALID_JWT); + } + + return generatedAssertion; + } catch (MsalClientException ex) { + throw ex; + } catch (Exception ex) { + throw new MsalClientException(ex); + } + } + return this.assertion; } @@ -30,11 +80,24 @@ public boolean equals(Object o) { if (!(o instanceof ClientAssertion)) return false; ClientAssertion other = (ClientAssertion) o; + + // For assertion providers, we consider them equal if they're the same object + if (this.assertionProvider != null && other.assertionProvider != null) { + return this.assertionProvider == other.assertionProvider; + } + + // For static assertions, compare the assertion strings return Objects.equals(assertion(), other.assertion()); } @Override public int hashCode() { + // For assertion providers, use the provider's identity hash code + if (assertionProvider != null) { + return System.identityHashCode(assertionProvider); + } + + // For static assertions, hash the assertion string int result = 1; result = result * 59 + (this.assertion == null ? 43 : this.assertion.hashCode()); return result; diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCertificate.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCertificate.java index fe3047e8..8b1aac33 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCertificate.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCertificate.java @@ -55,6 +55,33 @@ public List getEncodedPublicKeyCertificateChain() throws CertificateEnco return result; } + /** + * Gets a newly created JWT assertion using the certificate. + *

+ * This method creates a fresh JWT assertion on each call, which prevents issues + * with token expiration and ensures each request has a unique assertion. + * + * @param authority The authority for which the assertion is being created, must not be null + * @param clientId The client ID of the application, used as the subject of the JWT + * @param sendX5c Whether to include the x5c claim (certificate chain) in the JWT + * @return A JWT assertion for client authentication + * @throws NullPointerException if authority is null + */ + public String getAssertion(Authority authority, String clientId, boolean sendX5c) { + if (authority == null) { + throw new NullPointerException("Authority cannot be null"); + } + + boolean useSha1 = Authority.detectAuthorityType(authority.canonicalAuthorityUrl()) == AuthorityType.ADFS; + + return JwtHelper.buildJwt( + clientId, + this, + authority.selfSignedJwtAudience(), + sendX5c, + useSha1).assertion(); + } + static ClientCertificate create(InputStream pkcs12Certificate, String password) throws KeyStoreException, NoSuchProviderException, NoSuchAlgorithmException, CertificateException, IOException, UnrecoverableKeyException { diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCredentialFactory.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCredentialFactory.java index 1d93319a..1a915885 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCredentialFactory.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCredentialFactory.java @@ -91,15 +91,18 @@ public static IClientAssertion createFromClientAssertion(String clientAssertion) /** * Static method to create a {@link ClientAssertion} instance from a provided Callable. + * The callable will be invoked each time the assertion is needed, allowing for dynamic + * generation of assertions. * * @param callable Callable that produces a JWT token encoded as a base64 URL encoded string - * @return {@link ClientAssertion} + * @return {@link ClientAssertion} that will invoke the callable each time assertion() is called + * @throws NullPointerException if callable is null */ - public static IClientAssertion createFromCallback(Callable callable) throws ExecutionException, InterruptedException { - ExecutorService executor = Executors.newSingleThreadExecutor(); - - Future future = executor.submit(callable); + public static IClientAssertion createFromCallback(Callable callable) { + if (callable == null) { + throw new NullPointerException("callable"); + } - return new ClientAssertion(future.get()); + return new ClientAssertion(callable); } } diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ConfidentialClientApplication.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ConfidentialClientApplication.java index abe8fb09..51fd4610 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ConfidentialClientApplication.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ConfidentialClientApplication.java @@ -19,17 +19,14 @@ */ public class ConfidentialClientApplication extends AbstractClientApplicationBase implements IConfidentialClientApplication { - private ClientCertificate clientCertificate; - String assertion; - String secret; + IClientCredential clientCredential; + private boolean sendX5c; /** AppTokenProvider creates a Credential from a function that provides access tokens. The function must be concurrency safe. This is intended only to allow the Azure SDK to cache MSI tokens. It isn't useful to applications in general because the token provider must implement all authentication logic. */ public Function> appTokenProvider; - private boolean sendX5c; - @Override public CompletableFuture acquireToken(ClientCredentialParameters parameters) { validateNotNull("parameters", parameters); @@ -73,43 +70,11 @@ private ConfidentialClientApplication(Builder builder) { log = LoggerFactory.getLogger(ConfidentialClientApplication.class); - initClientAuthentication(builder.clientCredential); + this.clientCredential = builder.clientCredential; this.tenant = this.authenticationAuthority.tenant; } - private void initClientAuthentication(IClientCredential clientCredential) { - validateNotNull("clientCredential", clientCredential); - - if (clientCredential instanceof ClientSecret) { - this.secret = ((ClientSecret) clientCredential).clientSecret(); - } else if (clientCredential instanceof ClientCertificate) { - this.clientCertificate = (ClientCertificate) clientCredential; - this.assertion = getAssertionString(clientCredential); - } else if (clientCredential instanceof ClientAssertion) { - this.assertion = getAssertionString(clientCredential); - } else { - throw new IllegalArgumentException("Unsupported client credential"); - } - } - - String getAssertionString(IClientCredential clientCredential) { - if (clientCredential instanceof ClientCertificate) { - boolean useSha1 = Authority.detectAuthorityType(this.authenticationAuthority.canonicalAuthorityUrl()) == AuthorityType.ADFS; - - return JwtHelper.buildJwt( - clientId(), - clientCertificate, - this.authenticationAuthority.selfSignedJwtAudience(), - sendX5c, - useSha1).assertion(); - } else if (clientCredential instanceof ClientAssertion) { - return ((ClientAssertion) clientCredential).assertion(); - } else { - throw new IllegalArgumentException("Unsupported client credential"); - } - } - /** * Creates instance of Builder of ConfidentialClientApplication * @@ -137,6 +102,9 @@ public static class Builder extends AbstractClientApplicationBase.Builder queryParameters, + ConfidentialClientApplication application) { + IClientCredential credentialToUse = application.clientCredential; + Authority authorityToUse = application.authenticationAuthority; + + // A ClientCredentialRequest may have parameters which override the credentials used to build the application. + if (msalRequest instanceof ClientCredentialRequest) { + ClientCredentialParameters parameters = ((ClientCredentialRequest) msalRequest).parameters; + + if (parameters.clientCredential() != null) { + credentialToUse = parameters.clientCredential(); + } + + if (parameters.tenant() != null) { + try { + authorityToUse = Authority.replaceTenant(authorityToUse, parameters.tenant()); + } catch (MalformedURLException e) { + log.warn("Could not create authority with tenant override: {}", e.getMessage()); } } } - oauthHttpRequest.setQuery(StringHelper.serializeQueryParameters(queryParameters)); + // Quick return if no credential is provided + if (credentialToUse == null) { + return; + } + + if (credentialToUse instanceof ClientSecret) { + // For client secret, add client_secret parameter + queryParameters.put("client_secret", ((ClientSecret) credentialToUse).clientSecret()); + } else if (credentialToUse instanceof ClientAssertion) { + // For client assertion, add client_assertion and client_assertion_type parameters + addJWTBearerAssertionParams(queryParameters, ((ClientAssertion) credentialToUse).assertion()); + } else if (credentialToUse instanceof ClientCertificate) { + // For client certificate, generate a new assertion and add it to the request + ClientCertificate certificate = (ClientCertificate) credentialToUse; + String assertion = certificate.getAssertion( + authorityToUse, + application.clientId(), + application.sendX5c()); + addJWTBearerAssertionParams(queryParameters, assertion); + } } + /** + * Adds the JWT bearer token assertion parameters to the request + * + * @param queryParameters The map of query parameters to add to + * @param assertion The JWT assertion string + */ private void addJWTBearerAssertionParams(Map queryParameters, String assertion) { queryParameters.put("client_assertion", assertion); queryParameters.put("client_assertion_type", ClientAssertion.ASSERTION_TYPE_JWT_BEARER); diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ClientCertificateTest.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ClientCertificateTest.java index 99eb8649..ba9c34ef 100644 --- a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ClientCertificateTest.java +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ClientCertificateTest.java @@ -3,12 +3,12 @@ package com.microsoft.aad.msal4j; -import com.nimbusds.oauth2.sdk.auth.PrivateKeyJWT; import com.nimbusds.jwt.SignedJWT; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -20,8 +20,9 @@ import java.security.*; import java.security.cert.CertificateException; import java.security.interfaces.RSAPrivateKey; -import java.text.ParseException; import java.util.*; +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; @TestInstance(TestInstance.Lifecycle.PER_CLASS) class ClientCertificateTest { @@ -77,12 +78,17 @@ void testIClientCertificateInterface_CredentialFactoryUsesSha256() throws Except HttpRequest request = parameters.getArgument(0); String requestBody = request.body(); - SignedJWT signedJWT = SignedJWT.parse(cca.assertion); + String clientAssertion = extractClientAssertion(requestBody); - if (requestBody.contains(cca.assertion) - && signedJWT.getHeader().toJSONObject().containsKey("x5t#S256")) { - return TestHelper.expectedResponse(200, TestHelper.getSuccessfulTokenResponse(tokenResponseValues)); + if (clientAssertion != null) { + SignedJWT signedJWT = SignedJWT.parse(clientAssertion); + if (signedJWT.getHeader().toJSONObject().containsKey("x5t#S256")) { + return TestHelper.expectedResponse(200, TestHelper.getSuccessfulTokenResponse(tokenResponseValues)); + } } + + //If the client assertion is null or does not contain the x5t#S256 header, + // that indicates a problem in assertion generation and this test should fail. return null; }); @@ -94,6 +100,228 @@ void testIClientCertificateInterface_CredentialFactoryUsesSha256() throws Except assertEquals("accessTokenSha256", result.accessToken()); } + @Test + void testClientCertificate_GeneratesNewAssertionEachTime() throws Exception { + DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); + List capturedAssertions = new ArrayList<>(); + + ConfidentialClientApplication cca = + ConfidentialClientApplication.builder("clientId", ClientCredentialFactory.createFromCertificate(TestHelper.getPrivateKey(), TestHelper.getX509Cert())) + .authority("https://login.microsoftonline.com/tenant") + .instanceDiscovery(false) + .validateAuthority(false) + .httpClient(httpClientMock) + .build(); + + // Mock the HTTP client to capture assertions from each request + when(httpClientMock.send(any(HttpRequest.class))).thenAnswer(parameters -> { + HttpRequest request = parameters.getArgument(0); + String requestBody = request.body(); + + String clientAssertion = extractClientAssertion(requestBody); + if (clientAssertion != null) { + capturedAssertions.add(clientAssertion); + + // Verify it's a valid JWT with proper headers + SignedJWT signedJWT = SignedJWT.parse(clientAssertion); + if (signedJWT.getHeader().toJSONObject().containsKey("x5t#S256")) { + HashMap tokenResponseValues = new HashMap<>(); + tokenResponseValues.put("access_token", "access_token_" + capturedAssertions.size()); + return TestHelper.expectedResponse(200, TestHelper.getSuccessfulTokenResponse(tokenResponseValues)); + } + } + return null; + }); + + ClientCredentialParameters parameters = ClientCredentialParameters.builder(Collections.singleton("scopes")).skipCache(true).build(); + + // Make two token requests + IAuthenticationResult result1 = cca.acquireToken(parameters).get(); + IAuthenticationResult result2 = cca.acquireToken(parameters).get(); + + // Verify two unique assertions were generated + assertEquals(2, capturedAssertions.size(), "Two assertions should have been generated"); + assertNotEquals(capturedAssertions.get(0), capturedAssertions.get(1), + "Each token request should generate a unique assertion"); + + // Optional: Parse and verify JWT properties if needed + SignedJWT jwt1 = SignedJWT.parse(capturedAssertions.get(0)); + SignedJWT jwt2 = SignedJWT.parse(capturedAssertions.get(1)); + + // Different JTI (JWT ID) should be used for each assertion + assertNotEquals(jwt1.getJWTClaimsSet().getJWTID(), jwt2.getJWTClaimsSet().getJWTID(), + "Each assertion should have a unique JTI claim"); + + // Verify the tokens are different + assertNotEquals(result1.accessToken(), result2.accessToken(), + "The access tokens from each request should be different"); + } + + @Test + void testClientCertificate_TenantOverride() throws Exception { + DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); + Map capturedTenants = new HashMap<>(); + + ConfidentialClientApplication cca = + ConfidentialClientApplication.builder("clientId", ClientCredentialFactory.createFromCertificate(TestHelper.getPrivateKey(), TestHelper.getX509Cert())) + .authority("https://login.microsoftonline.com/default-tenant") + .instanceDiscovery(false) + .validateAuthority(false) + .httpClient(httpClientMock) + .build(); + + // Mock the HTTP client to capture and analyze assertions from each request + when(httpClientMock.send(any(HttpRequest.class))).thenAnswer(parameters -> { + HttpRequest request = parameters.getArgument(0); + String requestBody = request.body(); + String url = request.url().toString(); + + // Capture which tenant was used in the authority + String tenant = extractTenantFromUrl(url); + + // Extract the assertion to verify its audience claim + String clientAssertion = extractClientAssertion(requestBody); + if (clientAssertion != null) { + SignedJWT signedJWT = SignedJWT.parse(clientAssertion); + + // Get the audience claim to verify it matches the tenant + String audience = signedJWT.getJWTClaimsSet().getAudience().get(0); + + // Store the tenant and audience for verification + capturedTenants.put(tenant, audience); + + // Verify it's a valid JWT with proper headers + if (signedJWT.getHeader().toJSONObject().containsKey("x5t#S256")) { + HashMap tokenResponseValues = new HashMap<>(); + tokenResponseValues.put("access_token", "access_token_for_" + tenant); + return TestHelper.expectedResponse(200, TestHelper.getSuccessfulTokenResponse(tokenResponseValues)); + } + } + return null; + }); + + // First request with default tenant + ClientCredentialParameters defaultParameters = ClientCredentialParameters.builder(Collections.singleton("scopes")) + .skipCache(true) + .build(); + IAuthenticationResult resultDefault = cca.acquireToken(defaultParameters).get(); + + // Second request with override tenant + String overrideTenant = "override-tenant"; + ClientCredentialParameters overrideParameters = ClientCredentialParameters.builder(Collections.singleton("scopes")) + .skipCache(true) + .tenant(overrideTenant) + .build(); + IAuthenticationResult resultOverride = cca.acquireToken(overrideParameters).get(); + + // Verify both requests were processed + assertEquals(2, capturedTenants.size(), "Two requests with different tenants should have been processed"); + + // Verify both tenants were used + assertTrue(capturedTenants.containsKey("default-tenant"), "Default tenant should have been used"); + assertTrue(capturedTenants.containsKey(overrideTenant), "Override tenant should have been used"); + + // Verify the audience in the JWT assertions reflects the different tenants + assertNotEquals( + capturedTenants.get("default-tenant"), + capturedTenants.get(overrideTenant), + "JWT audience should differ between default and override tenant" + ); + + // Verify the audience claims match the expected format with the correct tenant + assertTrue( + capturedTenants.get("default-tenant").contains("default-tenant"), + "Audience for default tenant should contain the default tenant name" + ); + assertTrue( + capturedTenants.get(overrideTenant).contains(overrideTenant), + "Audience for override tenant should contain the override tenant name" + ); + + // Verify different access tokens were returned + assertNotEquals(resultDefault.accessToken(), resultOverride.accessToken(), + "Access tokens should differ when using different tenants"); + } + + @Test + void testClientCertificate_TenantOverride_B2C() throws Exception { + DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); + String replacementTenant = "overrideTenant"; + + ConfidentialClientApplication cca = + ConfidentialClientApplication.builder("clientId", ClientCredentialFactory.createFromCertificate(TestHelper.getPrivateKey(), TestHelper.getX509Cert())) + .b2cAuthority(TestConfiguration.B2C_AUTHORITY) + .instanceDiscovery(false) + .validateAuthority(false) + .httpClient(httpClientMock) + .build(); + + when(httpClientMock.send(any(HttpRequest.class))).thenAnswer(parameters -> { + HttpRequest request = parameters.getArgument(0); + String requestBody = request.body(); + String url = request.url().toString(); + + // Extract the assertion to verify its audience claim + String clientAssertion = extractClientAssertion(requestBody); + + if (clientAssertion != null && url.contains(replacementTenant)) { + HashMap tokenResponseValues = new HashMap<>(); + tokenResponseValues.put("access_token", "access_token_for_" + replacementTenant); + return TestHelper.expectedResponse(200, TestHelper.getSuccessfulTokenResponse(tokenResponseValues)); + } + + return null; + }); + + ClientCredentialParameters overrideParameters = ClientCredentialParameters.builder(Collections.singleton("scopes")) + .skipCache(true) + .tenant(replacementTenant) + .build(); + IAuthenticationResult result = cca.acquireToken(overrideParameters).get(); + + assertNotNull(result); + assertEquals("access_token_for_"+ replacementTenant, result.accessToken()); + } + + /** + * Extracts the tenant name from an authority URL + * @param url The full URL containing the tenant + * @return The tenant name + */ + private String extractTenantFromUrl(String url) { + // Authority URL format is typically https://login.microsoftonline.com/tenant/... + String[] parts = url.split("/"); + for (int i = 0; i < parts.length; i++) { + if (parts[i].equalsIgnoreCase("login.microsoftonline.com") && i + 1 < parts.length) { + return parts[i + 1]; + } + } + return null; + } + + /** + * Extracts the client_assertion value from a URL-encoded request body + * @param requestBody The request body string + * @return The extracted client assertion or null if not found + */ + private String extractClientAssertion(String requestBody) { + try { + // Split the request body into key-value pairs + String[] pairs = requestBody.split("&"); + for (String pair : pairs) { + // Find the client_assertion parameter + if (pair.startsWith("client_assertion=")) { + // Extract and URL-decode the value + return URLDecoder.decode(pair.substring("client_assertion=".length()), StandardCharsets.UTF_8.toString()); + } + } + } catch (Exception e) { + // In case of any parsing errors + System.err.println("Error extracting client assertion: " + e.getMessage()); + } + return null; + } + class TestClientCredential implements IClientCertificate { @Override public PrivateKey privateKey() { diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ClientCredentialTest.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ClientCredentialTest.java index 4b92b472..5c818757 100644 --- a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ClientCredentialTest.java +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ClientCredentialTest.java @@ -5,6 +5,8 @@ import java.util.Collections; import java.util.HashMap; +import java.util.concurrent.Callable; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; @@ -26,8 +28,9 @@ void testAssertionNullAndEmpty() { assertThrows(NullPointerException.class, () -> new ClientAssertion("")); + // Cast to String to resolve ambiguity between constructors assertThrows(NullPointerException.class, () -> - new ClientAssertion(null)); + new ClientAssertion((String) null)); } @Test @@ -107,4 +110,112 @@ void OnBehalfOf_TenantOverride() throws Exception { assertNotEquals(resultAppLevelTenant.accessToken(), resultRequestLevelTenant.accessToken()); verify(httpClientMock, times(2)).send(any()); } + + @Test + void testCredentialPrecedenceAndMixing() throws Exception { + DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); + + // Create different credential types for testing + IClientCredential appLevelCredential = ClientCredentialFactory.createFromSecret("appLevelSecret"); + IClientCredential requestLevelSecret = ClientCredentialFactory.createFromSecret("requestLevelSecret"); + String assertionValue = "test_assertion_value"; + IClientCredential requestLevelAssertion = ClientCredentialFactory.createFromClientAssertion(assertionValue); + + // Create the application with the app-level credential + ConfidentialClientApplication cca = + ConfidentialClientApplication.builder("clientId", appLevelCredential) + .authority("https://login.microsoftonline.com/tenant") + .instanceDiscovery(false) + .validateAuthority(false) + .httpClient(httpClientMock) + .build(); + + // Set up the mock to check which credential is being used + when(httpClientMock.send(any(HttpRequest.class))).thenAnswer(invocation -> { + HttpRequest request = invocation.getArgument(0); + String requestBody = request.body(); + + // Check which credential type is included in the request and return a matching token + if (requestBody.contains("client_secret=requestLevelSecret")) { + HashMap responseParams = new HashMap<>(); + responseParams.put("access_token", "request_secret_token"); + return TestHelper.expectedResponse(HttpStatus.HTTP_OK, + TestHelper.getSuccessfulTokenResponse(responseParams)); + } else if (requestBody.contains("client_secret=appLevelSecret")) { + HashMap responseParams = new HashMap<>(); + responseParams.put("access_token", "app_secret_token"); + return TestHelper.expectedResponse(HttpStatus.HTTP_OK, + TestHelper.getSuccessfulTokenResponse(responseParams)); + } else if (requestBody.contains("client_assertion=" + assertionValue)) { + HashMap responseParams = new HashMap<>(); + responseParams.put("access_token", "assertion_token"); + return TestHelper.expectedResponse(HttpStatus.HTTP_OK, + TestHelper.getSuccessfulTokenResponse(responseParams)); + } + return null; + }); + + // Test 1: Request with same credential type (secret) at request level + ClientCredentialParameters parametersWithRequestSecret = + ClientCredentialParameters.builder(Collections.singleton("scope")) + .clientCredential(requestLevelSecret) + .skipCache(true) + .build(); + + IAuthenticationResult result1 = cca.acquireToken(parametersWithRequestSecret).get(); + assertEquals("request_secret_token", result1.accessToken(), + "Request-level secret should be used when provided"); + + // Test 2: Request with different credential type (assertion) at request level + ClientCredentialParameters parametersWithAssertion = + ClientCredentialParameters.builder(Collections.singleton("scope")) + .clientCredential(requestLevelAssertion) + .skipCache(true) + .build(); + + IAuthenticationResult result2 = cca.acquireToken(parametersWithAssertion).get(); + assertEquals("assertion_token", result2.accessToken(), + "Request-level assertion should be used when provided"); + + // Test 3: Request without credential specified should fall back to app-level + ClientCredentialParameters parametersWithoutCredential = + ClientCredentialParameters.builder(Collections.singleton("scope")) + .skipCache(true) + .build(); + + IAuthenticationResult result3 = cca.acquireToken(parametersWithoutCredential).get(); + assertEquals("app_secret_token", result3.accessToken(), + "App-level credential should be used when request-level credential is not provided"); + } + + @Test + void acquireTokenClientCredentials_Callback() { + // Create a counter to track how many times our callback is invoked + final AtomicInteger callCounter = new AtomicInteger(0); + + // Create a callable that returns a different value each time it's called + // by including the counter value in the returned string + Callable callable = () -> { + int currentCount = callCounter.incrementAndGet(); + return "assertion_" + currentCount; + }; + + // Create the client assertion using our callback + IClientAssertion credential = ClientCredentialFactory.createFromCallback(callable); + + // Each call to assertion() should invoke our callable + String assertion1 = credential.assertion(); + String assertion2 = credential.assertion(); + String assertion3 = credential.assertion(); + + // Verify the callable was called three times, generating three different assertions + assertEquals(3, callCounter.get(), "Callable should have been invoked exactly three times"); + assertEquals("assertion_1", assertion1, "First assertion value should match first call"); + assertEquals("assertion_2", assertion2, "Second assertion value should match second call"); + assertEquals("assertion_3", assertion3, "Third assertion value should match third call"); + + // Verify assertions are different from each other + assertNotEquals(assertion1, assertion2, "First and second assertions should be different"); + assertNotEquals(assertion2, assertion3, "Second and third assertions should be different"); + } }