Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Consolidate common MSAL options
  • Loading branch information
chlowell committed Oct 6, 2022
commit ec84e3f973def7941cb7f0712b88500e429a3449
31 changes: 31 additions & 0 deletions sdk/azidentity/azidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,37 @@ const (
tenantIDValidationErr = "invalid tenantID. You can locate your tenantID by following the instructions listed here: https://docs.microsoft.com/partner-center/find-ids-and-domain-names"
)

func getConfidentialClient(clientID, tenantID string, cred confidential.Credential, co *azcore.ClientOptions, additionalOpts ...confidential.Option) (confidential.Client, error) {
if !validTenantID(tenantID) {
return confidential.Client{}, errors.New(tenantIDValidationErr)
}
authorityHost, err := setAuthorityHost(co.Cloud)
if err != nil {
return confidential.Client{}, err
}
o := []confidential.Option{
confidential.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)),
confidential.WithAzureRegion(os.Getenv(azureRegionalAuthorityName)),
confidential.WithHTTPClient(newPipelineAdapter(co)),
}
o = append(o, additionalOpts...)
return confidential.New(clientID, cred, o...)
}

func getPublicClient(clientID, tenantID string, co *azcore.ClientOptions) (public.Client, error) {
if !validTenantID(tenantID) {
return public.Client{}, errors.New(tenantIDValidationErr)
}
authorityHost, err := setAuthorityHost(co.Cloud)
if err != nil {
return public.Client{}, err
}
return public.New(clientID,
public.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)),
public.WithHTTPClient(newPipelineAdapter(co)),
)
}

// setAuthorityHost initializes the authority host for credentials. Precedence is:
// 1. cloud.Configuration.ActiveDirectoryAuthorityHost value set by user
// 2. value of AZURE_AUTHORITY_HOST
Expand Down
15 changes: 1 addition & 14 deletions sdk/azidentity/client_assertion_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@ package azidentity
import (
"context"
"errors"
"os"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
)

Expand All @@ -39,26 +37,15 @@ func NewClientAssertionCredential(tenantID, clientID string, getAssertion func(c
if getAssertion == nil {
return nil, errors.New("getAssertion must be a function that returns assertions")
}
if !validTenantID(tenantID) {
return nil, errors.New(tenantIDValidationErr)
}
if options == nil {
options = &ClientAssertionCredentialOptions{}
}
authorityHost, err := setAuthorityHost(options.Cloud)
if err != nil {
return nil, err
}
cred := confidential.NewCredFromAssertionCallback(
func(ctx context.Context, _ confidential.AssertionRequestOptions) (string, error) {
return getAssertion(ctx)
},
)
c, err := confidential.New(clientID, cred,
confidential.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)),
confidential.WithAzureRegion(os.Getenv(azureRegionalAuthorityName)),
confidential.WithHTTPClient(newPipelineAdapter(&options.ClientOptions)),
)
c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions)
if err != nil {
return nil, err
}
Expand Down
17 changes: 2 additions & 15 deletions sdk/azidentity/client_certificate_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ import (
"crypto/x509"
"encoding/pem"
"errors"
"os"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
"golang.org/x/crypto/pkcs12"
)
Expand All @@ -43,29 +41,18 @@ func NewClientCertificateCredential(tenantID string, clientID string, certs []*x
if len(certs) == 0 {
return nil, errors.New("at least one certificate is required")
}
if !validTenantID(tenantID) {
return nil, errors.New(tenantIDValidationErr)
}
if options == nil {
options = &ClientCertificateCredentialOptions{}
}
authorityHost, err := setAuthorityHost(options.Cloud)
if err != nil {
return nil, err
}
cred, err := confidential.NewCredFromCertChain(certs, key)
if err != nil {
return nil, err
}
o := []confidential.Option{
confidential.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)),
confidential.WithHTTPClient(newPipelineAdapter(&options.ClientOptions)),
confidential.WithAzureRegion(os.Getenv(azureRegionalAuthorityName)),
}
var o []confidential.Option
if options.SendCertificateChain {
o = append(o, confidential.WithX5C())
}
c, err := confidential.New(clientID, cred, o...)
c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions, o...)
if err != nil {
return nil, err
}
Expand Down
15 changes: 1 addition & 14 deletions sdk/azidentity/client_secret_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@ package azidentity
import (
"context"
"errors"
"os"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
)

Expand All @@ -31,25 +29,14 @@ type ClientSecretCredential struct {

// NewClientSecretCredential constructs a ClientSecretCredential. Pass nil for options to accept defaults.
func NewClientSecretCredential(tenantID string, clientID string, clientSecret string, options *ClientSecretCredentialOptions) (*ClientSecretCredential, error) {
if !validTenantID(tenantID) {
return nil, errors.New(tenantIDValidationErr)
}
if options == nil {
options = &ClientSecretCredentialOptions{}
}
authorityHost, err := setAuthorityHost(options.Cloud)
if err != nil {
return nil, err
}
cred, err := confidential.NewCredFromSecret(clientSecret)
if err != nil {
return nil, err
}
c, err := confidential.New(clientID, cred,
confidential.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)),
confidential.WithHTTPClient(newPipelineAdapter(&options.ClientOptions)),
confidential.WithAzureRegion(os.Getenv(azureRegionalAuthorityName)),
)
c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions)
if err != nil {
return nil, err
}
Expand Down
13 changes: 1 addition & 12 deletions sdk/azidentity/device_code_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
)

Expand Down Expand Up @@ -79,17 +78,7 @@ func NewDeviceCodeCredential(options *DeviceCodeCredentialOptions) (*DeviceCodeC
cp = *options
}
cp.init()
if !validTenantID(cp.TenantID) {
return nil, errors.New(tenantIDValidationErr)
}
authorityHost, err := setAuthorityHost(cp.Cloud)
if err != nil {
return nil, err
}
c, err := public.New(cp.ClientID,
public.WithAuthority(runtime.JoinPaths(authorityHost, cp.TenantID)),
public.WithHTTPClient(newPipelineAdapter(&cp.ClientOptions)),
)
c, err := getPublicClient(cp.ClientID, cp.TenantID, &cp.ClientOptions)
if err != nil {
return nil, err
}
Expand Down
13 changes: 1 addition & 12 deletions sdk/azidentity/interactive_browser_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
)

Expand Down Expand Up @@ -56,17 +55,7 @@ func NewInteractiveBrowserCredential(options *InteractiveBrowserCredentialOption
cp = *options
}
cp.init()
if !validTenantID(cp.TenantID) {
return nil, errors.New(tenantIDValidationErr)
}
authorityHost, err := setAuthorityHost(cp.Cloud)
if err != nil {
return nil, err
}
c, err := public.New(cp.ClientID,
public.WithAuthority(runtime.JoinPaths(authorityHost, cp.TenantID)),
public.WithHTTPClient(newPipelineAdapter(&cp.ClientOptions)),
)
c, err := getPublicClient(cp.ClientID, cp.TenantID, &cp.ClientOptions)
if err != nil {
return nil, err
}
Expand Down
13 changes: 1 addition & 12 deletions sdk/azidentity/username_password_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
)

Expand All @@ -37,20 +36,10 @@ type UsernamePasswordCredential struct {
// NewUsernamePasswordCredential creates a UsernamePasswordCredential. clientID is the ID of the application the user
// will authenticate to. Pass nil for options to accept defaults.
func NewUsernamePasswordCredential(tenantID string, clientID string, username string, password string, options *UsernamePasswordCredentialOptions) (*UsernamePasswordCredential, error) {
if !validTenantID(tenantID) {
return nil, errors.New(tenantIDValidationErr)
}
if options == nil {
options = &UsernamePasswordCredentialOptions{}
}
authorityHost, err := setAuthorityHost(options.Cloud)
if err != nil {
return nil, err
}
c, err := public.New(clientID,
public.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)),
public.WithHTTPClient(newPipelineAdapter(&options.ClientOptions)),
)
c, err := getPublicClient(clientID, tenantID, &options.ClientOptions)
if err != nil {
return nil, err
}
Expand Down