Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
## 1.3.0-beta.3 (Unreleased)

### Features Added
* By default, credentials set client capability "CP1" to enable support for
[Continuous Access Evaluation (CAE)](https://docs.microsoft.com/azure/active-directory/develop/app-resilience-continuous-access-evaluation).
This indicates to Azure Active Directory that your application can handle CAE claims challenges.
You can disable this behavior by setting the environment variable "AZURE_IDENTITY_DISABLE_CP1" to "true".
* `InteractiveBrowserCredentialOptions.LoginHint` enables pre-populating the login
prompt with a username ([#15599](https://github.com/Azure/azure-sdk-for-go/pull/15599))

Expand Down
14 changes: 13 additions & 1 deletion sdk/azidentity/azidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ const (
tenantIDValidationErr = "invalid tenantID. You can locate your tenantID by following the instructions listed here: https://docs.microsoft.com/partner-center/find-ids-and-domain-names"
)

var (
// capability CP1 indicates the client application is capable of handling CAE claims challenges
cp1 = []string{"CP1"}
disableCP1 = strings.ToLower(os.Getenv("AZURE_IDENTITY_DISABLE_CP1")) == "true"
)

var getConfidentialClient = func(clientID, tenantID string, cred confidential.Credential, co *azcore.ClientOptions, additionalOpts ...confidential.Option) (confidentialClient, error) {
if !validTenantID(tenantID) {
return confidential.Client{}, errors.New(tenantIDValidationErr)
Expand All @@ -58,6 +64,9 @@ var getConfidentialClient = func(clientID, tenantID string, cred confidential.Cr
confidential.WithAzureRegion(os.Getenv(azureRegionalAuthorityName)),
confidential.WithHTTPClient(newPipelineAdapter(co)),
}
if !disableCP1 {
o = append(o, confidential.WithClientCapabilities(cp1))
}
o = append(o, additionalOpts...)
if strings.ToLower(tenantID) == "adfs" {
o = append(o, confidential.WithInstanceDiscovery(false))
Expand All @@ -73,11 +82,14 @@ var getPublicClient = func(clientID, tenantID string, co *azcore.ClientOptions,
if err != nil {
return public.Client{}, err
}

o := []public.Option{
public.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)),
public.WithClientCapabilities(cp1),
public.WithHTTPClient(newPipelineAdapter(co)),
}
if !disableCP1 {
o = append(o, public.WithClientCapabilities(cp1))
}
o = append(o, additionalOpts...)
if strings.ToLower(tenantID) == "adfs" {
o = append(o, public.WithInstanceDiscovery(false))
Expand Down
111 changes: 103 additions & 8 deletions sdk/azidentity/azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"crypto/x509"
"errors"
"fmt"
"io"
"net/http"
"os"
"strings"
Expand Down Expand Up @@ -136,16 +135,20 @@ func getTenantDiscoveryResponse(tenant string) []byte {

func validateX5C(t *testing.T, certs []*x509.Certificate) mock.ResponsePredicate {
return func(req *http.Request) bool {
body, err := io.ReadAll(req.Body)
err := req.ParseForm()
if err != nil {
t.Fatal("Expected a request with the JWT in the body.")
t.Fatal("expected a form body")
}
bodystr := string(body)
kvps := strings.Split(bodystr, "&")
assertion := strings.Split(kvps[0], "=")
token, _ := jwt.Parse(assertion[1], nil)
assertion, ok := req.PostForm["client_assertion"]
if !ok {
t.Fatal("expected a client_assertion field")
}
if len(assertion) != 1 {
t.Fatalf(`unexpected client_assertion "%v"`, assertion)
}
token, _ := jwt.Parse(assertion[0], nil)
if token == nil {
t.Fatalf("Failed to parse the JWT token: %s.", assertion[1])
t.Fatalf("failed to parse the assertion: %s", assertion)
}
if v, ok := token.Header["x5c"].([]any); !ok {
t.Fatal("missing x5c header")
Expand Down Expand Up @@ -525,6 +528,98 @@ func TestAdditionallyAllowedTenants(t *testing.T) {
}
}

func TestClaims(t *testing.T) {
claim := `"test":"pass"`
for _, test := range []struct {
ctor func(azcore.ClientOptions) (azcore.TokenCredential, error)
name string
}{
{
name: credNameAssertion,
ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) {
o := ClientAssertionCredentialOptions{ClientOptions: co}
return NewClientAssertionCredential(fakeTenantID, fakeClientID, func(context.Context) (string, error) { return "...", nil }, &o)
},
},
{
name: credNameCert,
ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) {
o := ClientCertificateCredentialOptions{ClientOptions: co}
return NewClientCertificateCredential(fakeTenantID, fakeClientID, allCertTests[0].certs, allCertTests[0].key, &o)
},
},
{
name: credNameDeviceCode,
ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) {
o := DeviceCodeCredentialOptions{
ClientOptions: co,
UserPrompt: func(context.Context, DeviceCodeMessage) error { return nil },
}
return NewDeviceCodeCredential(&o)
},
},
{
name: credNameOBO,
ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) {
o := OnBehalfOfCredentialOptions{ClientOptions: co}
return NewOnBehalfOfCredentialFromSecret(fakeTenantID, fakeClientID, "assertion", fakeSecret, &o)
},
},
{
name: credNameSecret,
ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) {
o := ClientSecretCredentialOptions{ClientOptions: co}
return NewClientSecretCredential(fakeTenantID, fakeClientID, fakeSecret, &o)
},
},
{
name: credNameUserPassword,
ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) {
o := UsernamePasswordCredentialOptions{ClientOptions: co}
return NewUsernamePasswordCredential(fakeTenantID, fakeClientID, fakeUsername, "password", &o)
},
},
} {
t.Run(test.name, func(t *testing.T) {
reqs := 0
sts := mockSTS{
tokenRequestCallback: func(r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Error(err)
}
reqs += 1
// Both requests should specify CP1. The second GetToken call specifies claims we should
// find in the token request. We check only for the expected substring because MSAL is
// responsible for formatting claims.
actual := r.Form["claims"]
if len(actual) != 1 || !strings.Contains(actual[0], "CP1") {
t.Fatalf(`unexpected claims "%v"`, actual)
}
if reqs == 2 {
if !strings.Contains(strings.ReplaceAll(actual[0], " ", ""), claim) {
t.Fatalf(`unexpected claims "%v"`, actual)
}
}
},
}
o := azcore.ClientOptions{Transport: &sts}
cred, err := test.ctor(o)
if err != nil {
t.Fatal(err)
}
if _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"A"}}); err != nil {
t.Fatal(err)
}
if _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Claims: fmt.Sprintf("{%s}", claim), Scopes: []string{"B"}}); err != nil {
t.Fatal(err)
}
if reqs != 2 {
t.Fatalf("expected %d token requests, got %d", 2, reqs)
}
})
}
}

func TestResolveTenant(t *testing.T) {
defaultTenant := "default-tenant"
otherTenant := "other-tenant"
Expand Down
4 changes: 2 additions & 2 deletions sdk/azidentity/client_assertion_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ func (c *ClientAssertionCredential) GetToken(ctx context.Context, opts policy.To
if err != nil {
return azcore.AccessToken{}, err
}
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithTenantID(tenant))
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err == nil {
logGetTokenSuccessImpl(c.name, opts)
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}

ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithTenantID(tenant))
ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err != nil {
return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(c.name, err)
}
Expand Down
4 changes: 2 additions & 2 deletions sdk/azidentity/client_certificate_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ func (c *ClientCertificateCredential) GetToken(ctx context.Context, opts policy.
if err != nil {
return azcore.AccessToken{}, err
}
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithTenantID(tenant))
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err == nil {
logGetTokenSuccess(c, opts)
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}

ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithTenantID(tenant))
ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err != nil {
return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameCert, err)
}
Expand Down
4 changes: 2 additions & 2 deletions sdk/azidentity/client_secret_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ func (c *ClientSecretCredential) GetToken(ctx context.Context, opts policy.Token
if err != nil {
return azcore.AccessToken{}, err
}
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithTenantID(tenant))
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err == nil {
logGetTokenSuccess(c, opts)
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}

ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithTenantID(tenant))
ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err != nil {
return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameSecret, err)
}
Expand Down
8 changes: 6 additions & 2 deletions sdk/azidentity/device_code_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,15 @@ func (c *DeviceCodeCredential) GetToken(ctx context.Context, opts policy.TokenRe
if err != nil {
return azcore.AccessToken{}, err
}
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, public.WithSilentAccount(c.account), public.WithTenantID(tenant))
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes,
public.WithClaims(opts.Claims),
public.WithSilentAccount(c.account),
public.WithTenantID(tenant),
)
if err == nil {
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}
dc, err := c.client.AcquireTokenByDeviceCode(ctx, opts.Scopes, public.WithTenantID(tenant))
dc, err := c.client.AcquireTokenByDeviceCode(ctx, opts.Scopes, public.WithClaims(opts.Claims), public.WithTenantID(tenant))
if err != nil {
return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameDeviceCode, err)
}
Expand Down
7 changes: 6 additions & 1 deletion sdk/azidentity/interactive_browser_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,18 @@ func (c *InteractiveBrowserCredential) GetToken(ctx context.Context, opts policy
if err != nil {
return azcore.AccessToken{}, err
}
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, public.WithSilentAccount(c.account), public.WithTenantID(tenant))
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes,
public.WithClaims(opts.Claims),
public.WithSilentAccount(c.account),
public.WithTenantID(tenant),
)
if err == nil {
logGetTokenSuccess(c, opts)
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}

ar, err = c.client.AcquireTokenInteractive(ctx, opts.Scopes,
public.WithClaims(opts.Claims),
public.WithLoginHint(c.options.LoginHint),
public.WithRedirectURI(c.options.RedirectURL),
public.WithTenantID(tenant),
Expand Down
5 changes: 4 additions & 1 deletion sdk/azidentity/on_behalf_of_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ func (o *OnBehalfOfCredential) GetToken(ctx context.Context, opts policy.TokenRe
if err != nil {
return azcore.AccessToken{}, err
}
ar, err := o.client.AcquireTokenOnBehalfOf(ctx, o.assertion, opts.Scopes, confidential.WithTenantID(tenant))
ar, err := o.client.AcquireTokenOnBehalfOf(ctx, o.assertion, opts.Scopes,
confidential.WithClaims(opts.Claims),
confidential.WithTenantID(tenant),
)
if err != nil {
return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameOBO, err)
}
Expand Down
8 changes: 6 additions & 2 deletions sdk/azidentity/username_password_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,16 @@ func (c *UsernamePasswordCredential) GetToken(ctx context.Context, opts policy.T
if err != nil {
return azcore.AccessToken{}, err
}
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, public.WithSilentAccount(c.account), public.WithTenantID(tenant))
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes,
public.WithClaims(opts.Claims),
public.WithSilentAccount(c.account),
public.WithTenantID(tenant),
)
if err == nil {
logGetTokenSuccess(c, opts)
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}
ar, err = c.client.AcquireTokenByUsernamePassword(ctx, opts.Scopes, c.username, c.password, public.WithTenantID(tenant))
ar, err = c.client.AcquireTokenByUsernamePassword(ctx, opts.Scopes, c.username, c.password, public.WithClaims(opts.Claims), public.WithTenantID(tenant))
if err != nil {
return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameUserPassword, err)
}
Expand Down