Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -199,74 +199,125 @@ public void createAadTokenCredential() throws InterruptedException {
}

@Test(groups = { "long-emulator" }, timeOut = 10 * TIMEOUT)
public void testAadScopeOverride() throws Exception {
CosmosAsyncClient setupClient = null;
CosmosAsyncClient aadClient = null;
String containerName = UUID.randomUUID().toString();
String overrideScope = "https://cosmos.azure.com/.default";
public void overrideScope_only_noFallback_onSuccess() throws Exception {
CosmosAsyncClient client = null;
ScopeRecorder.clear();

final String overrideScope = "https://my.custom.scope/.default";
setEnv(Configs.AAD_SCOPE_OVERRIDE_VARIABLE, overrideScope);

try {
setupClient = new CosmosClientBuilder()
TokenCredential cred = new AadSimpleEmulatorTokenCredential(TestConfigurations.MASTER_KEY);

client = new CosmosClientBuilder()
.endpoint(TestConfigurations.HOST)
.key(TestConfigurations.MASTER_KEY)
.credential(cred)
.buildAsyncClient();

setupClient.createDatabase(databaseId).block();
setupClient.getDatabase(databaseId).createContainer(containerName, PARTITION_KEY_PATH).block();
} finally {
if (setupClient != null) {
safeClose(setupClient);
client.readAllDatabases().byPage().blockFirst();

java.util.List<String> scopes = ScopeRecorder.all();
assert scopes.size() >= 1 : "Expected at least one AAD call";
for (String s : scopes) {
assert overrideScope.equals(s) : "Expected only override scope; saw: " + scopes;
}
} finally {
if (client != null) safeClose(client);
setEnv(Configs.AAD_SCOPE_OVERRIDE_VARIABLE, Configs.DEFAULT_AAD_SCOPE_OVERRIDE);
}
}

Thread.sleep(TIMEOUT);
@Test(groups = { "long-emulator"}, timeOut = 10 * TIMEOUT)
public void overrideScope_authError_noFallback() throws Exception {
CosmosAsyncClient client = null;
ScopeRecorder.clear();

final String overrideScope = "https://my.custom.scope/.default";
setEnv(Configs.AAD_SCOPE_OVERRIDE_VARIABLE, overrideScope);

TokenCredential emulatorCredential =
new AadSimpleEmulatorTokenCredential(TestConfigurations.MASTER_KEY);
try {
TokenCredential cred = new AlwaysFail500011Credential();

aadClient = new CosmosClientBuilder()
.endpoint(TestConfigurations.HOST)
.credential(emulatorCredential)
.buildAsyncClient();
client = new CosmosClientBuilder()
.endpoint(TestConfigurations.HOST)
.credential(cred)
.buildAsyncClient();

try {
CosmosAsyncContainer container = aadClient
.getDatabase(databaseId)
.getContainer(containerName);
try {
client.readAllDatabases().byPage().blockFirst();
assert false : "Expected an auth failure with override scope";
} catch (Exception ex) {
// Only the override scope should have been attempted; no fallback allowed
java.util.List<String> scopes = ScopeRecorder.all();
assert scopes.size() >= 1 : "Expected at least one scope attempt";
for (String s : scopes) {
assert overrideScope.equals(s) : "No fallback allowed in override mode; saw: " + scopes;
}
}
} finally {
if (client != null) safeClose(client);
setEnv(Configs.AAD_SCOPE_OVERRIDE_VARIABLE, Configs.DEFAULT_AAD_SCOPE_OVERRIDE);
}
}

String itemId = UUID.randomUUID().toString();
String pk = UUID.randomUUID().toString();
ItemSample item = getDocumentDefinition(itemId, pk);
@Test(groups = { "long-emulator"}, timeOut = 10 * TIMEOUT)
public void accountScope_only_whenNoOverride_andNoAuthFailure() throws Exception {
CosmosAsyncClient client = null;
ScopeRecorder.clear();
setEnv(Configs.AAD_SCOPE_OVERRIDE_VARIABLE, Configs.DEFAULT_AAD_SCOPE_OVERRIDE);

container.createItem(item).block();
java.net.URI ep = new java.net.URI(TestConfigurations.HOST);
String accountScope = ep.getScheme() + "://" + ep.getHost() + "/.default";

List<String> scopes = AadSimpleEmulatorTokenCredential.getLastScopes();
assert scopes != null && scopes.size() == 1;
assert overrideScope.equals(scopes.get(0));
try {
TokenCredential cred = new AadSimpleEmulatorTokenCredential(TestConfigurations.MASTER_KEY);

container.deleteItem(item.id, new PartitionKey(item.mypk)).block();
} finally {
try {
CosmosAsyncClient cleanupClient = new CosmosClientBuilder()
.endpoint(TestConfigurations.HOST)
.key(TestConfigurations.MASTER_KEY)
.buildAsyncClient();
try {
cleanupClient.getDatabase(databaseId).delete().block();
} finally {
safeClose(cleanupClient);
}
} finally {
if (aadClient != null) {
safeClose(aadClient);
}
setEnv(Configs.AAD_SCOPE_OVERRIDE_VARIABLE, Configs.DEFAULT_AAD_SCOPE_OVERRIDE);
client = new CosmosClientBuilder()
.endpoint(TestConfigurations.HOST)
.credential(cred)
.buildAsyncClient();

client.readAllDatabases().byPage().blockFirst();

java.util.List<String> scopes = ScopeRecorder.all();
assert scopes.size() >= 1 : "Expected at least one AAD call";
for (String s : scopes) {
assert accountScope.equals(s) : "Expected only account scope; saw: " + scopes;
}
} finally {
if (client != null) safeClose(client);
setEnv(Configs.AAD_SCOPE_OVERRIDE_VARIABLE, Configs.DEFAULT_AAD_SCOPE_OVERRIDE);
}
}

Thread.sleep(SHUTDOWN_TIMEOUT);
@Test(groups = { "long-emulator"}, timeOut = 10 * TIMEOUT)
public void accountScope_fallbackToCosmosScope_onAadSts500011() throws Exception {
CosmosAsyncClient client = null;
ScopeRecorder.clear();
setEnv(Configs.AAD_SCOPE_OVERRIDE_VARIABLE, Configs.DEFAULT_AAD_SCOPE_OVERRIDE);

java.net.URI ep = new java.net.URI(TestConfigurations.HOST);
String accountScope = ep.getScheme() + "://" + ep.getHost() + "/.default";
String fallbackScope = "https://cosmos.azure.com/.default";

try {
// Fail on account scope with AADSTS500011
TokenCredential cred = new AccountThenFallbackCredential(TestConfigurations.MASTER_KEY, accountScope);

client = new CosmosClientBuilder()
.endpoint(TestConfigurations.HOST)
.credential(cred)
.buildAsyncClient();

client.readAllDatabases().byPage().blockFirst();

java.util.List<String> scopes = ScopeRecorder.all();
assert scopes.contains(accountScope) : "Expected primary account scope attempt; saw: " + scopes;
assert scopes.contains(fallbackScope) : "Expected fallback to cosmos public scope; saw: " + scopes;
} finally {
if (client != null) safeClose(client);
setEnv(Configs.AAD_SCOPE_OVERRIDE_VARIABLE, Configs.DEFAULT_AAD_SCOPE_OVERRIDE);
}
}

@SuppressWarnings({"unchecked", "rawtypes"})
Expand Down Expand Up @@ -298,6 +349,65 @@ private static void setEnv(String key, String value) throws Exception {
}
}

// Records all scopes used during the test run (append-only).
static final class ScopeRecorder {
private static final java.util.concurrent.CopyOnWriteArrayList<String> SEEN = new java.util.concurrent.CopyOnWriteArrayList<>();
static void clear() { SEEN.clear(); }
static void record(TokenRequestContext ctx) {
java.util.List<String> s = ctx.getScopes();
if (s != null) SEEN.addAll(s);
}
static java.util.List<String> all() { return new java.util.ArrayList<>(SEEN); }
}

/**
* Always fails with an AADSTS500011 message for any scope.
* Used to prove that the "override scope" path does NOT fallback.
*/
static class AlwaysFail500011Credential implements TokenCredential {
@Override
public Mono<AccessToken> getToken(TokenRequestContext tokenRequestContext) {
ScopeRecorder.record(tokenRequestContext);
return Mono.error(new RuntimeException("AADSTS500011: Application <id> was not found in the directory"));
}
}

static class AccountThenFallbackCredential implements TokenCredential {
private final AadSimpleEmulatorTokenCredential emulatorIssuer;
private final String accountScope;
private final String cosmosPublicScope = "https://cosmos.azure.com/.default";
private final java.util.concurrent.atomic.AtomicInteger calls = new java.util.concurrent.atomic.AtomicInteger(0);

AccountThenFallbackCredential(String emulatorKey, String accountScope) {
this.emulatorIssuer = new AadSimpleEmulatorTokenCredential(emulatorKey);
this.accountScope = accountScope;
}

@Override
public Mono<AccessToken> getToken(TokenRequestContext tokenRequestContext) {
ScopeRecorder.record(tokenRequestContext);

String scope = tokenRequestContext.getScopes() != null && !tokenRequestContext.getScopes().isEmpty()
? tokenRequestContext.getScopes().get(0)
: "";

int n = calls.incrementAndGet();

// Fail on the first attempt if it is the account scope, to trigger fallback
if (n == 1 && scope.equals(accountScope)) {
return Mono.error(new RuntimeException("AADSTS500011: Application <id> was not found in the directory"));
}

// When SDK retries with the cosmos public scope, succeed with a valid emulator token
if (scope.equals(cosmosPublicScope)) {
return emulatorIssuer.getToken(tokenRequestContext);
}

// If anything unexpected happens, fail loudly so the test points to the issue
return Mono.error(new IllegalStateException("Unexpected scope or call ordering. Scope=" + scope + " call=" + n));
}
}

private ItemSample getDocumentDefinition(String itemId, String partitionKeyValue) {
ItemSample itemSample = new ItemSample();
itemSample.id = itemId;
Expand Down Expand Up @@ -329,11 +439,6 @@ static class AadSimpleEmulatorTokenCredential implements TokenCredential {
private final String AAD_HEADER_COSMOS_EMULATOR = "{\"typ\":\"JWT\",\"alg\":\"RS256\",\"x5t\":\"CosmosEmulatorPrimaryMaster\",\"kid\":\"CosmosEmulatorPrimaryMaster\"}";
private final String AAD_CLAIM_COSMOS_EMULATOR_FORMAT = "{\"aud\":\"https://localhost.localhost\",\"iss\":\"https://sts.fake-issuer.net/7b1999a1-dfd7-440e-8204-00170979b984\",\"iat\":%d,\"nbf\":%d,\"exp\":%d,\"aio\":\"\",\"appid\":\"localhost\",\"appidacr\":\"1\",\"idp\":\"https://localhost:8081/\",\"oid\":\"96313034-4739-43cb-93cd-74193adbe5b6\",\"rh\":\"\",\"sub\":\"localhost\",\"tid\":\"EmulatorFederation\",\"uti\":\"\",\"ver\":\"1.0\",\"scp\":\"user_impersonation\",\"groups\":[\"7ce1d003-4cb3-4879-b7c5-74062a35c66e\",\"e99ff30c-c229-4c67-ab29-30a6aebc3e58\",\"5549bb62-c77b-4305-bda9-9ec66b85d9e4\",\"c44fd685-5c58-452c-aaf7-13ce75184f65\",\"be895215-eab5-43b7-9536-9ef8fe130330\"]}";

private static volatile List<String> lastScopes = Collections.emptyList();

public static List<String> getLastScopes() {
return lastScopes;
}
public AadSimpleEmulatorTokenCredential(String emulatorKey) {
if (emulatorKey == null || emulatorKey.isEmpty()) {
throw new IllegalArgumentException("emulatorKey");
Expand All @@ -344,10 +449,8 @@ public AadSimpleEmulatorTokenCredential(String emulatorKey) {

@Override
public Mono<AccessToken> getToken(TokenRequestContext tokenRequestContext) {
List<String> scopes = tokenRequestContext.getScopes(); // List<String>, not String[]
lastScopes = (scopes != null && !scopes.isEmpty())
? new ArrayList<>(scopes)
: Collections.emptyList();
// Record scopes for verification in tests
AadAuthorizationTests.ScopeRecorder.record(tokenRequestContext);

String aadToken = emulatorKey_based_AAD_String();
return Mono.just(new AccessToken(aadToken, OffsetDateTime.now().plusHours(2)));
Expand Down
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#### Features Added
* Added `ThroughputBucket` support for throughput control. - [PR 46042](https://github.com/Azure/azure-sdk-for-java/pull/46042)
* AAD Auth: Adds a fallback mechanism for AAD audience scope. - [PR 46637](https://github.com/Azure/azure-sdk-for-java/pull/46637)

#### Bugs Fixed
* Fixed 404/1002 for query when container recreated with same name. - [PR 45930](https://github.com/Azure/azure-sdk-for-java/pull/45930)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,4 +292,5 @@ public static final class QueryExecutionContext {
}

public static final int QUERYPLAN_CACHE_SIZE = 5000;
public static final String AAD_DEFAULT_SCOPE = "https://cosmos.azure.com/.default";
}
Original file line number Diff line number Diff line change
Expand Up @@ -533,17 +533,68 @@ private RxDocumentClientImpl(URI serviceEndpoint,
hasAuthKeyResourceToken = false;
this.authorizationTokenType = AuthorizationTokenType.PrimaryMasterKey;
this.authorizationTokenProvider = new BaseAuthorizationTokenProvider(this.credential);
} else {
}else {
hasAuthKeyResourceToken = false;
this.authorizationTokenProvider = null;
if (tokenCredential != null) {
String scopeOverride = Configs.getAadScopeOverride();
String defaultScope = serviceEndpoint.getScheme() + "://" + serviceEndpoint.getHost() + "/.default";
String scopeToUse = (scopeOverride != null && !scopeOverride.isEmpty()) ? scopeOverride : defaultScope;
String accountScope = serviceEndpoint.getScheme() + "://" + serviceEndpoint.getHost() + "/.default";

if (scopeOverride != null && !scopeOverride.isEmpty()) {
// Use only the override scope; no fallback.
this.tokenCredentialScopes = new String[] { scopeOverride };

this.tokenCredentialCache = new SimpleTokenCache(() -> {
final String primaryScope = this.tokenCredentialScopes[0];
final TokenRequestContext ctx = new TokenRequestContext().addScopes(primaryScope);
return this.tokenCredential.getToken(ctx)
.doOnNext(t -> {
if (logger.isInfoEnabled()) {
logger.info("AAD token: acquired using override scope: {}", primaryScope);
}
});
});
} else {
// Account scope with fallback to default scope on AADSTS500011 error
this.tokenCredentialScopes = new String[] { accountScope, Constants.AAD_DEFAULT_SCOPE };

this.tokenCredentialCache = new SimpleTokenCache(() -> {
final String primaryScope = this.tokenCredentialScopes[0];
final String fallbackScope = this.tokenCredentialScopes[1];

this.tokenCredentialScopes = new String[] { scopeToUse };
this.tokenCredentialCache = new SimpleTokenCache(() -> this.tokenCredential
.getToken(new TokenRequestContext().addScopes(this.tokenCredentialScopes)));
final TokenRequestContext primaryCtx = new TokenRequestContext().addScopes(primaryScope);

return this.tokenCredential.getToken(primaryCtx)
.doOnNext(t -> {
if (logger.isInfoEnabled()) {
logger.info("AAD token: acquired using account scope: {}", primaryScope);
}
})
.onErrorResume(error -> {
final Throwable root = reactor.core.Exceptions.unwrap(error);
final String messageText = (root.getMessage() != null) ? root.getMessage() : "";
final boolean isAadAppNotFound = messageText.contains("AADSTS500011");

if (!isAadAppNotFound) {
return Mono.error(error);
}

if (logger.isWarnEnabled()) {
logger.warn(
"AAD token: account scope failed with AADSTS500011; retrying with fallback scope: {}",
fallbackScope);
}

final TokenRequestContext fallbackCtx = new TokenRequestContext().addScopes(fallbackScope);
return this.tokenCredential.getToken(fallbackCtx)
.doOnNext(t -> {
if (logger.isInfoEnabled()) {
logger.info("AAD token: acquired using fallback scope: {}", fallbackScope);
}
});
});
});
}
this.authorizationTokenType = AuthorizationTokenType.AadToken;
}
}
Expand Down
Loading