Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
copy shared code to azure-keyvault-secrets
  • Loading branch information
chlowell committed Jul 17, 2019
commit 3fd7bd1c1964af49b674d8babe9a4816dfe17e92
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@
# Licensed under the MIT License.
# ------------------------------------
from ._client import SecretClient
from ._models import Secret, SecretAttributes, DeletedSecret

__all__ = ["SecretClient"]
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError

from ._internal import _KeyVaultClientBase
from ._shared import KeyVaultClientBase
from ._models import Secret, DeletedSecret, SecretAttributes


class SecretClient(_KeyVaultClientBase):
class SecretClient(KeyVaultClientBase):
"""SecretClient is a high-level interface for managing a vault's secrets.

Example:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

import datetime
from typing import Any, Dict, Mapping, Optional
from ._generated.v7_0 import models
from ._internal import _parse_vault_id

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
from typing import Any, Dict, Mapping, Optional

from ._shared import parse_vault_id
from ._shared._generated.v7_0 import models


class SecretAttributes(object):
Expand All @@ -16,7 +23,7 @@ def __init__(self, attributes, vault_id, **kwargs):
# type: (models.SecretAttributes, str, Mapping[str, Any]) -> None
self._attributes = attributes
self._id = vault_id
self._vault_id = _parse_vault_id(vault_id)
self._vault_id = parse_vault_id(vault_id)
self._content_type = kwargs.get("content_type", None)
self._key_id = kwargs.get("key_id", None)
self._managed = kwargs.get("managed", None)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from collections import namedtuple

try:
import urllib.parse as parse
except ImportError:
# pylint:disable=import-error
import urlparse as parse # type: ignore

from .challenge_auth_policy import ChallengeAuthPolicy, ChallengeAuthPolicyBase
from .client_base import KeyVaultClientBase
from .http_challenge import HttpChallenge
from . import http_challenge_cache as HttpChallengeCache

__all__ = [
"ChallengeAuthPolicy",
"ChallengeAuthPolicyBase",
"HttpChallenge",
"HttpChallengeCache",
"KeyVaultClientBase",
]

_VaultId = namedtuple("VaultId", ["vault_url", "collection", "name", "version"])


def parse_vault_id(url):
try:
parsed_uri = parse.urlparse(url)
except Exception: # pylint: disable=broad-except
raise ValueError("'{}' is not not a valid url".format(url))
if not (parsed_uri.scheme and parsed_uri.hostname):
raise ValueError("'{}' is not not a valid url".format(url))

path = list(filter(None, parsed_uri.path.split("/")))

if len(path) < 2 or len(path) > 3:
raise ValueError("'{}' is not not a valid vault url".format(url))

return _VaultId(
vault_url="{}://{}".format(parsed_uri.scheme, parsed_uri.hostname),
collection=path[0],
name=path[1],
version=path[2] if len(path) == 3 else None,
)


try:
from .async_challenge_auth_policy import AsyncChallengeAuthPolicy
from .async_client_base import AsyncKeyVaultClientBase, AsyncPagingAdapter

__all__.extend(["AsyncChallengeAuthPolicy", "AsyncKeyVaultClientBase", "AsyncPagingAdapter"])
except (SyntaxError, ImportError):
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import AsyncHTTPPolicy
from azure.core.pipeline.transport import HttpRequest, HttpResponse

from . import ChallengeAuthPolicyBase, HttpChallenge, HttpChallengeCache


class AsyncChallengeAuthPolicy(ChallengeAuthPolicyBase, AsyncHTTPPolicy):
"""policy for handling HTTP authentication challenges"""

async def send(self, request: PipelineRequest) -> HttpResponse:
challenge = HttpChallengeCache.get_challenge_for_url(request.http_request.url)
if not challenge:
# provoke a challenge with an unauthorized, bodiless request
no_body = HttpRequest(
request.http_request.method, request.http_request.url, headers=request.http_request.headers
)
if request.http_request.body:
# no_body was created with request's headers -> if request has a body, no_body's content-length is wrong
no_body.headers["Content-Length"] = "0"

challenger = await self.next.send(PipelineRequest(http_request=no_body, context=request.context))
try:
challenge = self._update_challenge(request, challenger)
except ValueError:
# didn't receive the expected challenge -> nothing more this policy can do
return challenger

await self._handle_challenge(request, challenge)
response = await self.next.send(request)

if response.http_response.status_code == 401:
# cached challenge could be outdated; maybe this response has a new one?
try:
challenge = self._update_challenge(request, response)
except ValueError:
# 401 with no legible challenge -> nothing more this policy can do
return response

await self._handle_challenge(request, challenge)
response = await self.next.send(request)

return response

async def _handle_challenge(self, request: PipelineRequest, challenge: HttpChallenge) -> None:
"""authenticate according to challenge, add Authorization header to request"""

scope = challenge.get_resource()
if not scope.endswith("/.default"):
scope += "/.default"

access_token = await self._credential.get_token(scope)
self._update_headers(request.http_request.headers, access_token.token)
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
from azure.core.async_paging import AsyncPagedMixin
from azure.core.configuration import Configuration
from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
from azure.core.pipeline.transport import AsyncioRequestsTransport, HttpTransport
from msrest.serialization import Model

from .._generated import KeyVaultClient
from .._internal import KEY_VAULT_SCOPE
from ._generated import KeyVaultClient
from . import AsyncChallengeAuthPolicy


if TYPE_CHECKING:
Expand Down Expand Up @@ -41,7 +40,7 @@ async def __anext__(self) -> Any:
# TODO: expected type Model got Coroutine instead?


class _AsyncKeyVaultClientBase:
class AsyncKeyVaultClientBase:
"""
:param credential: A credential or credential provider which can be used to authenticate to the vault,
a ValueError will be raised if the entity is not provided
Expand All @@ -58,7 +57,7 @@ def create_config(
if api_version is None:
api_version = KeyVaultClient.DEFAULT_API_VERSION
config = KeyVaultClient.get_configuration_class(api_version, aio=True)(credential, **kwargs)
config.authentication_policy = AsyncBearerTokenCredentialPolicy(credential, KEY_VAULT_SCOPE)
config.authentication_policy = AsyncChallengeAuthPolicy(credential)
return config

def __init__(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
# pylint:disable=unused-import
from azure.core.pipeline.transport import HttpResponse

from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import HTTPPolicy
from azure.core.pipeline.policies.authentication import _BearerTokenCredentialPolicyBase
from azure.core.pipeline.transport import HttpRequest

from .http_challenge import HttpChallenge
from . import http_challenge_cache as ChallengeCache


class ChallengeAuthPolicyBase(_BearerTokenCredentialPolicyBase):
"""Sans I/O base for challenge authentication policies"""

def __init__(self, credential, **kwargs):
super(ChallengeAuthPolicyBase, self).__init__(credential, **kwargs)

@staticmethod
def _update_challenge(request, challenger):
# type: (HttpRequest, HttpResponse) -> HttpChallenge
"""parse challenge from challenger, cache it, return it"""

challenge = HttpChallenge(
request.http_request.url,
challenger.http_response.headers.get("WWW-Authenticate"),
response_headers=challenger.http_response.headers,
)
ChallengeCache.set_challenge_for_url(request.http_request.url, challenge)
return challenge


class ChallengeAuthPolicy(ChallengeAuthPolicyBase, HTTPPolicy):
"""policy for handling HTTP authentication challenges"""

def send(self, request):
# type: (PipelineRequest) -> HttpResponse

challenge = ChallengeCache.get_challenge_for_url(request.http_request.url)
if not challenge:
# provoke a challenge with an unauthorized, bodiless request
no_body = HttpRequest(
request.http_request.method, request.http_request.url, headers=request.http_request.headers
)
if request.http_request.body:
# no_body was created with request's headers -> if request has a body, no_body's content-length is wrong
no_body.headers["Content-Length"] = "0"

challenger = self.next.send(PipelineRequest(http_request=no_body, context=request.context))
try:
challenge = self._update_challenge(request, challenger)
except ValueError:
# didn't receive the expected challenge -> nothing more this policy can do
return challenger

self._handle_challenge(request, challenge)
response = self.next.send(request)

if response.http_response.status_code == 401:
# cached challenge could be outdated; maybe this response has a new one?
try:
challenge = self._update_challenge(request, response)
except ValueError:
# 401 with no legible challenge -> nothing more this policy can do
return response

self._handle_challenge(request, challenge)
response = self.next.send(request)

return response

def _handle_challenge(self, request, challenge):
# type: (PipelineRequest, HttpChallenge) -> None
"""authenticate according to challenge, add Authorization header to request"""

scope = challenge.get_resource()
if not scope.endswith("/.default"):
scope += "/.default"

access_token = self._credential.get_token(scope)
self._update_headers(request.http_request.headers, access_token.token)
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from collections import namedtuple
from typing import TYPE_CHECKING
from azure.core import Configuration
from azure.core.pipeline import Pipeline
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.pipeline.transport import RequestsTransport
from ._generated import KeyVaultClient

Expand All @@ -16,40 +14,13 @@
from azure.core.credentials import TokenCredential
from azure.core.pipeline.transport import HttpTransport

try:
import urllib.parse as parse
except ImportError:
import urlparse as parse # pylint: disable=import-error


_VaultId = namedtuple("VaultId", ["vault_url", "collection", "name", "version"])
from .challenge_auth_policy import ChallengeAuthPolicy


KEY_VAULT_SCOPE = "https://vault.azure.net/.default"


def _parse_vault_id(url):
try:
parsed_uri = parse.urlparse(url)
except Exception: # pylint: disable=broad-except
raise ValueError("'{}' is not not a valid url".format(url))
if not (parsed_uri.scheme and parsed_uri.hostname):
raise ValueError("'{}' is not not a valid url".format(url))

path = list(filter(None, parsed_uri.path.split("/")))

if len(path) < 2 or len(path) > 3:
raise ValueError("'{}' is not not a valid vault url".format(url))

return _VaultId(
vault_url="{}://{}".format(parsed_uri.scheme, parsed_uri.hostname),
collection=path[0],
name=path[1],
version=path[2] if len(path) == 3 else None,
)


class _KeyVaultClientBase(object):
class KeyVaultClientBase(object):
"""
:param credential: A credential or credential provider which can be used to authenticate to the vault,
a ValueError will be raised if the entity is not provided
Expand All @@ -65,7 +36,7 @@ def create_config(credential, api_version=None, **kwargs):
if api_version is None:
api_version = KeyVaultClient.DEFAULT_API_VERSION
config = KeyVaultClient.get_configuration_class(api_version, aio=False)(credential, **kwargs)
config.authentication_policy = BearerTokenCredentialPolicy(credential, KEY_VAULT_SCOPE)
config.authentication_policy = ChallengeAuthPolicy(credential)
return config

def __init__(self, vault_url, credential, config=None, transport=None, api_version=None, **kwargs):
Expand Down
Loading