Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
## 1.3.0-beta.3 (Unreleased)

### Features Added
* By default, credentials set client capability "CP1" to enable support for
[Continuous Access Evaluation (CAE)](https://docs.microsoft.com/azure/active-directory/develop/app-resilience-continuous-access-evaluation).
This indicates to Azure Active Directory that your application can handle CAE claims challenges.
You can disable this behavior by setting the environment variable "AZURE_IDENTITY_DISABLE_CP1" to "true".
* `InteractiveBrowserCredentialOptions.LoginHint` enables pre-populating the login
prompt with a username ([#15599](https://github.com/Azure/azure-sdk-for-go/pull/15599))

Expand Down
13 changes: 12 additions & 1 deletion sdk/azidentity/azidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ const (
tenantIDValidationErr = "invalid tenantID. You can locate your tenantID by following the instructions listed here: https://docs.microsoft.com/partner-center/find-ids-and-domain-names"
)

var (
// capability CP1 indicates the client application is capable of handling CAE claims challenges
cp1 = []string{"CP1"}
disableCP1 = strings.ToLower(os.Getenv("AZURE_IDENTITY_DISABLE_CP1")) == "true"
)

var getConfidentialClient = func(clientID, tenantID string, cred confidential.Credential, co *azcore.ClientOptions, additionalOpts ...confidential.Option) (confidentialClient, error) {
if !validTenantID(tenantID) {
return confidential.Client{}, errors.New(tenantIDValidationErr)
Expand All @@ -58,6 +64,9 @@ var getConfidentialClient = func(clientID, tenantID string, cred confidential.Cr
confidential.WithAzureRegion(os.Getenv(azureRegionalAuthorityName)),
confidential.WithHTTPClient(newPipelineAdapter(co)),
}
if !disableCP1 {
o = append(o, confidential.WithClientCapabilities(cp1))
}
o = append(o, additionalOpts...)
if strings.ToLower(tenantID) == "adfs" {
o = append(o, confidential.WithInstanceDiscovery(false))
Expand All @@ -73,11 +82,13 @@ var getPublicClient = func(clientID, tenantID string, co *azcore.ClientOptions,
if err != nil {
return public.Client{}, err
}

o := []public.Option{
public.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)),
public.WithHTTPClient(newPipelineAdapter(co)),
}
if !disableCP1 {
o = append(o, public.WithClientCapabilities(cp1))
}
o = append(o, additionalOpts...)
if strings.ToLower(tenantID) == "adfs" {
o = append(o, public.WithInstanceDiscovery(false))
Expand Down
120 changes: 112 additions & 8 deletions sdk/azidentity/azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"crypto/x509"
"errors"
"fmt"
"io"
"net/http"
"os"
"strings"
Expand Down Expand Up @@ -136,16 +135,20 @@ func getTenantDiscoveryResponse(tenant string) []byte {

func validateX5C(t *testing.T, certs []*x509.Certificate) mock.ResponsePredicate {
return func(req *http.Request) bool {
body, err := io.ReadAll(req.Body)
err := req.ParseForm()
if err != nil {
t.Fatal("Expected a request with the JWT in the body.")
t.Fatal("expected a form body")
}
bodystr := string(body)
kvps := strings.Split(bodystr, "&")
assertion := strings.Split(kvps[0], "=")
token, _ := jwt.Parse(assertion[1], nil)
assertion, ok := req.PostForm["client_assertion"]
if !ok {
t.Fatal("expected a client_assertion field")
}
if len(assertion) != 1 {
t.Fatalf(`unexpected client_assertion "%v"`, assertion)
}
token, _ := jwt.Parse(assertion[0], nil)
if token == nil {
t.Fatalf("Failed to parse the JWT token: %s.", assertion[1])
t.Fatalf("failed to parse the assertion: %s", assertion)
}
if v, ok := token.Header["x5c"].([]any); !ok {
t.Fatal("missing x5c header")
Expand Down Expand Up @@ -525,6 +528,107 @@ func TestAdditionallyAllowedTenants(t *testing.T) {
}
}

func TestClaims(t *testing.T) {
realCP1 := disableCP1
t.Cleanup(func() { disableCP1 = realCP1 })
claim := `"test":"pass"`
for _, test := range []struct {
ctor func(azcore.ClientOptions) (azcore.TokenCredential, error)
name string
}{
{
name: credNameAssertion,
ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) {
o := ClientAssertionCredentialOptions{ClientOptions: co}
return NewClientAssertionCredential(fakeTenantID, fakeClientID, func(context.Context) (string, error) { return "...", nil }, &o)
},
},
{
name: credNameCert,
ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) {
o := ClientCertificateCredentialOptions{ClientOptions: co}
return NewClientCertificateCredential(fakeTenantID, fakeClientID, allCertTests[0].certs, allCertTests[0].key, &o)
},
},
{
name: credNameDeviceCode,
ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) {
o := DeviceCodeCredentialOptions{
ClientOptions: co,
UserPrompt: func(context.Context, DeviceCodeMessage) error { return nil },
}
return NewDeviceCodeCredential(&o)
},
},
{
name: credNameOBO,
ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) {
o := OnBehalfOfCredentialOptions{ClientOptions: co}
return NewOnBehalfOfCredentialFromSecret(fakeTenantID, fakeClientID, "assertion", fakeSecret, &o)
},
},
{
name: credNameSecret,
ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) {
o := ClientSecretCredentialOptions{ClientOptions: co}
return NewClientSecretCredential(fakeTenantID, fakeClientID, fakeSecret, &o)
},
},
{
name: credNameUserPassword,
ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) {
o := UsernamePasswordCredentialOptions{ClientOptions: co}
return NewUsernamePasswordCredential(fakeTenantID, fakeClientID, fakeUsername, "password", &o)
},
},
} {
for _, d := range []bool{true, false} {
name := test.name
if d {
name += " disableCP1"
}
t.Run(name, func(t *testing.T) {
disableCP1 = d
reqs := 0
sts := mockSTS{
tokenRequestCallback: func(r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Error(err)
}
reqs++
// If the disableCP1 flag isn't set, both requests should specify CP1. The second
// GetToken call specifies claims we should find in the following token request.
// We check only for substrings because MSAL is responsible for formatting claims.
actual := fmt.Sprint(r.Form["claims"])
if strings.Contains(actual, "CP1") == disableCP1 {
t.Fatalf(`unexpected claims "%v"`, actual)
}
if reqs == 2 {
if !strings.Contains(strings.ReplaceAll(actual, " ", ""), claim) {
t.Fatalf(`unexpected claims "%v"`, actual)
}
}
},
}
o := azcore.ClientOptions{Transport: &sts}
cred, err := test.ctor(o)
if err != nil {
t.Fatal(err)
}
if _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"A"}}); err != nil {
t.Fatal(err)
}
if _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Claims: fmt.Sprintf("{%s}", claim), Scopes: []string{"B"}}); err != nil {
t.Fatal(err)
}
if reqs != 2 {
t.Fatalf("expected %d token requests, got %d", 2, reqs)
}
})
}
}
}

func TestResolveTenant(t *testing.T) {
defaultTenant := "default-tenant"
otherTenant := "other-tenant"
Expand Down
4 changes: 2 additions & 2 deletions sdk/azidentity/client_assertion_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ func (c *ClientAssertionCredential) GetToken(ctx context.Context, opts policy.To
if err != nil {
return azcore.AccessToken{}, err
}
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithTenantID(tenant))
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err == nil {
logGetTokenSuccessImpl(c.name, opts)
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}

ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithTenantID(tenant))
ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err != nil {
return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(c.name, err)
}
Expand Down
4 changes: 2 additions & 2 deletions sdk/azidentity/client_certificate_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ func (c *ClientCertificateCredential) GetToken(ctx context.Context, opts policy.
if err != nil {
return azcore.AccessToken{}, err
}
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithTenantID(tenant))
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err == nil {
logGetTokenSuccess(c, opts)
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}

ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithTenantID(tenant))
ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err != nil {
return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameCert, err)
}
Expand Down
4 changes: 2 additions & 2 deletions sdk/azidentity/client_secret_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ func (c *ClientSecretCredential) GetToken(ctx context.Context, opts policy.Token
if err != nil {
return azcore.AccessToken{}, err
}
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithTenantID(tenant))
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err == nil {
logGetTokenSuccess(c, opts)
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}

ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithTenantID(tenant))
ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err != nil {
return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameSecret, err)
}
Expand Down
8 changes: 6 additions & 2 deletions sdk/azidentity/device_code_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,15 @@ func (c *DeviceCodeCredential) GetToken(ctx context.Context, opts policy.TokenRe
if err != nil {
return azcore.AccessToken{}, err
}
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, public.WithSilentAccount(c.account), public.WithTenantID(tenant))
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes,
public.WithClaims(opts.Claims),
public.WithSilentAccount(c.account),
public.WithTenantID(tenant),
)
if err == nil {
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}
dc, err := c.client.AcquireTokenByDeviceCode(ctx, opts.Scopes, public.WithTenantID(tenant))
dc, err := c.client.AcquireTokenByDeviceCode(ctx, opts.Scopes, public.WithClaims(opts.Claims), public.WithTenantID(tenant))
if err != nil {
return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameDeviceCode, err)
}
Expand Down
7 changes: 6 additions & 1 deletion sdk/azidentity/interactive_browser_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,18 @@ func (c *InteractiveBrowserCredential) GetToken(ctx context.Context, opts policy
if err != nil {
return azcore.AccessToken{}, err
}
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, public.WithSilentAccount(c.account), public.WithTenantID(tenant))
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes,
public.WithClaims(opts.Claims),
public.WithSilentAccount(c.account),
public.WithTenantID(tenant),
)
if err == nil {
logGetTokenSuccess(c, opts)
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}

ar, err = c.client.AcquireTokenInteractive(ctx, opts.Scopes,
public.WithClaims(opts.Claims),
public.WithLoginHint(c.options.LoginHint),
public.WithRedirectURI(c.options.RedirectURL),
public.WithTenantID(tenant),
Expand Down
5 changes: 4 additions & 1 deletion sdk/azidentity/on_behalf_of_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ func (o *OnBehalfOfCredential) GetToken(ctx context.Context, opts policy.TokenRe
if err != nil {
return azcore.AccessToken{}, err
}
ar, err := o.client.AcquireTokenOnBehalfOf(ctx, o.assertion, opts.Scopes, confidential.WithTenantID(tenant))
ar, err := o.client.AcquireTokenOnBehalfOf(ctx, o.assertion, opts.Scopes,
confidential.WithClaims(opts.Claims),
confidential.WithTenantID(tenant),
)
if err != nil {
return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameOBO, err)
}
Expand Down
8 changes: 6 additions & 2 deletions sdk/azidentity/username_password_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,16 @@ func (c *UsernamePasswordCredential) GetToken(ctx context.Context, opts policy.T
if err != nil {
return azcore.AccessToken{}, err
}
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, public.WithSilentAccount(c.account), public.WithTenantID(tenant))
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes,
public.WithClaims(opts.Claims),
public.WithSilentAccount(c.account),
public.WithTenantID(tenant),
)
if err == nil {
logGetTokenSuccess(c, opts)
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}
ar, err = c.client.AcquireTokenByUsernamePassword(ctx, opts.Scopes, c.username, c.password, public.WithTenantID(tenant))
ar, err = c.client.AcquireTokenByUsernamePassword(ctx, opts.Scopes, c.username, c.password, public.WithClaims(opts.Claims), public.WithTenantID(tenant))
if err != nil {
return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameUserPassword, err)
}
Expand Down