Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update auth policy for DPG
  • Loading branch information
mccoyp committed Jan 30, 2024
commit 308df23d0be29447924fc95c07a083896ece498b
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
protocol again.
"""

from copy import deepcopy
import time
from typing import Any, Optional
from urllib.parse import urlparse
Expand All @@ -22,6 +23,7 @@
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
from azure.core.rest import HttpRequest

from . import http_challenge_cache as ChallengeCache
from .challenge_auth_policy import _enforce_tls, _update_challenge
Expand All @@ -39,6 +41,7 @@ def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any
self._credential = credential
self._token: Optional[AccessToken] = None
self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True)
self._request_copy: Optional[HttpRequest] = None

async def on_request(self, request: PipelineRequest) -> None:
_enforce_tls(request)
Expand All @@ -60,12 +63,17 @@ async def on_request(self, request: PipelineRequest) -> None:

# else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data,
# saving it for later. Key Vault will reject the request as unauthorized and respond with a challenge.
# on_challenge will parse that challenge, reattach any body removed here, authorize the request, and tell
# super to send it again.
if request.http_request.body:
request.context["key_vault_request_data"] = request.http_request.body
request.http_request.set_json_body(None)
request.http_request.headers["Content-Length"] = "0"
# on_challenge will parse that challenge, use the original request including the body, authorize the
# request, and tell super to send it again.
if request.http_request.content:
self._request_copy = request.http_request
bodiless_request = HttpRequest(
method=request.http_request.method,
url=request.http_request.url,
headers=deepcopy(request.http_request.headers),
)
bodiless_request.headers["Content-Length"] = "0"
request.http_request = bodiless_request


async def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool:
Expand All @@ -89,8 +97,9 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons
"See https://aka.ms/azsdk/blog/vault-uri for more information."
)

body = request.context.pop("key_vault_request_data", None)
request.http_request.set_text_body(body) # no-op when text is None
# If we had created a request copy in on_request, use it now to send along the original body content
if self._request_copy:
request.http_request = self._request_copy

# The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication
# For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
protocol again.
"""

from copy import deepcopy
import time
from typing import Any, Optional
from urllib.parse import urlparse
Expand All @@ -22,6 +23,7 @@
from azure.core.exceptions import ServiceRequestError
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.rest import HttpRequest

from .http_challenge import HttpChallenge
from . import http_challenge_cache as ChallengeCache
Expand Down Expand Up @@ -68,6 +70,7 @@ def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) ->
self._credential = credential
self._token: Optional[AccessToken] = None
self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True)
self._request_copy: Optional[HttpRequest] = None

def on_request(self, request: PipelineRequest) -> None:
_enforce_tls(request)
Expand All @@ -89,12 +92,17 @@ def on_request(self, request: PipelineRequest) -> None:

# else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data,
# saving it for later. Key Vault will reject the request as unauthorized and respond with a challenge.
# on_challenge will parse that challenge, reattach any body removed here, authorize the request, and tell
# super to send it again.
if request.http_request.body:
request.context["key_vault_request_data"] = request.http_request.body
request.http_request.set_json_body(None)
request.http_request.headers["Content-Length"] = "0"
# on_challenge will parse that challenge, use the original request including the body, authorize the
# request, and tell super to send it again.
if request.http_request.content:
self._request_copy = request.http_request
bodiless_request = HttpRequest(
method=request.http_request.method,
url=request.http_request.url,
headers=deepcopy(request.http_request.headers),
)
bodiless_request.headers["Content-Length"] = "0"
request.http_request = bodiless_request

def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool:
try:
Expand All @@ -117,8 +125,9 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) ->
"See https://aka.ms/azsdk/blog/vault-uri for more information."
)

body = request.context.pop("key_vault_request_data", None)
request.http_request.set_text_body(body) # no-op when text is None
# If we had created a request copy in on_request, use it now to send along the original body content
if self._request_copy:
request.http_request = self._request_copy

# The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication
# For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648
Expand Down