Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
5947e8b
get other conn types in list op
MilesHolland Jan 12, 2024
4f006da
accept empty connections in rest-local conversion
MilesHolland Jan 25, 2024
e319d40
Merge branch 'main' of https://github.com/azure/azure-sdk-for-python
MilesHolland Jan 25, 2024
e5af527
Merge branch 'main' into connections-jan-asks
MilesHolland Jan 25, 2024
9318317
azure blob base changes
MilesHolland Jan 31, 2024
d986da9
Merge branch 'main' into connections-jan-asks
MilesHolland Jan 31, 2024
1b6d921
partial replacement of if-else with maps in wc
MilesHolland Feb 2, 2024
6263dc6
more if-else removal refactors
MilesHolland Feb 5, 2024
4b6f2dc
Merge branch 'main' into connections-jan-asks
MilesHolland Feb 5, 2024
5882887
solidify connection changes
MilesHolland Feb 6, 2024
47c9b51
run black on v2
MilesHolland Feb 6, 2024
de24812
Merge branch 'main' into connections-jan-asks
MilesHolland Feb 6, 2024
5c8930f
initial self review
MilesHolland Feb 6, 2024
4b4c988
spelling
MilesHolland Feb 6, 2024
6f539fd
rerun black and ai nit
MilesHolland Feb 6, 2024
ea5b62f
more nits, and push flag to v2
MilesHolland Feb 7, 2024
7d4ca46
e2e test
MilesHolland Feb 7, 2024
d87bba7
Merge branch 'main' into connections-jan-asks
MilesHolland Feb 7, 2024
663ea4c
e2e test recordings and use 2024 restclient
MilesHolland Feb 7, 2024
a3907f9
analyze ci issues
MilesHolland Feb 7, 2024
2e85ab2
Merge branch 'main' into connections-jan-asks
MilesHolland Feb 7, 2024
5aa8207
mypy fixes
MilesHolland Feb 8, 2024
7795de7
more mypy
MilesHolland Feb 8, 2024
ec10c08
Merge branch 'connections-jan-asks' of github.com:MilesHolland/azure-…
MilesHolland Feb 8, 2024
0e2d312
remove 2024 conn dep until CI wakes up
MilesHolland Feb 8, 2024
31939b5
skip problematic test
MilesHolland Feb 8, 2024
e3be48b
try bumping mindep
MilesHolland Feb 8, 2024
710b06a
remove skip
MilesHolland Feb 8, 2024
b922eec
undo 1.14.0 dep
MilesHolland Feb 8, 2024
9d800cb
add back in 1.14 dep
MilesHolland Feb 8, 2024
a408731
comments
MilesHolland Feb 12, 2024
10922a1
linting
MilesHolland Feb 12, 2024
d2c66ed
remove extra line
MilesHolland Feb 12, 2024
dba0cbd
more linting
MilesHolland Feb 13, 2024
66e202c
run black
MilesHolland Feb 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
initial self review
  • Loading branch information
MilesHolland committed Feb 6, 2024
commit 5c8930fa3d9cb30fbd271f088fcad96c510f75f4
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,26 @@

from typing import Any, Dict, Optional

from azure.ai.ml._utils.utils import camel_to_snake
from azure.ai.ml._utils.utils import camel_to_snake, _snake_to_camel
from azure.ai.ml.entities import WorkspaceConnection
from azure.ai.ml.entities._credentials import ApiKeyConfiguration
from azure.core.credentials import TokenCredential
from azure.ai.ml._restclient.v2023_06_01_preview.models import ConnectionCategory
from azure.ai.ml.constants._common import (
WorkspaceConnectionTypes,
)

class BaseConnection:
"""A connection to a another system or service. This is a base class and should not be
instantiated directly. Use the child classes that are specialized for connections to different services.
"""A connection to a another system or service. This is a base class and should only be directly
used to cover the simplest of connection types that don't merit their own subclass.

:param name: The name of the connection
:type name: str
:param target: The URL or ARM resource ID of the external resource.
:type target: str
:param type: The type of connection this represents. Acceptable values include:
"azure_open_ai", "cognitive_service", and "cognitive_search".
"azure_open_ai", "cognitive_service", "cognitive_search",
"git", "azure_blob", and "custom".
:type type: str
:param credentials: The credentials for authenticating to the external resource.
:type credentials: ~azure.ai.ml.entities.ApiKeyConfiguration
Expand Down Expand Up @@ -52,6 +54,8 @@ def __init__(
if kwargs.pop("make_empty", False):
return

# Dev note: this is an important line, as is quietly changes the internal object from
# A WorkspaceConnection to the relevant subclass if needed.
conn_class = WorkspaceConnection._get_entity_class_from_type(type)
self._workspace_connection = conn_class(
target=target,
Expand All @@ -71,9 +75,9 @@ def _from_v2_workspace_connection(cls, workspace_connection: WorkspaceConnection
:return: The converted connection.
:rtype: ~azure.ai.resources.entities.Connection
"""
# It's simpler to create a placeholder connection, then overwrite the internal WC.
# We don't need to worry about the potentially changing WC fields this way.
conn_class = cls._get_ai_connection_class_from_type(workspace_connection.type)
# It's simpler to create a placeholder connection, then overwrite it..
# We don't need to worry about the accidentally changing WC fields this way.
conn = conn_class(type="a", target="a", credentials=None, name="a", make_empty=True, api_version=None, api_type=None, kind=None)
conn._workspace_connection = workspace_connection
return conn
Expand All @@ -83,6 +87,10 @@ def _get_ai_connection_class_from_type(cls, conn_type: str):
"""Given a connection type string, get the corresponding AI SDK object via class
comparisons. Accounts for any potential camel/snake/capitalization issues. Returns
the BaseConnection class if no match is found.

:param conn_type: The type of the connection as a string. Should match one of the
supported connection category values.
:type conn_type: str
"""
#import here to avoid circular import
from .connection_subtypes import (
Expand All @@ -94,22 +102,23 @@ def _get_ai_connection_class_from_type(cls, conn_type: str):
AzureBlobStoreConnection
)

cat = camel_to_snake(conn_type).lower()
if cat == camel_to_snake(ConnectionCategory.AZURE_OPEN_AI).lower():
return AzureOpenAIConnection
elif cat == camel_to_snake(ConnectionCategory.COGNITIVE_SEARCH).lower():
return AzureAISearchConnection
elif cat == camel_to_snake(ConnectionCategory.COGNITIVE_SERVICE).lower():
return AzureAIServiceConnection
elif cat == camel_to_snake(ConnectionCategory.GIT).lower():
return GitHubConnection
elif cat == camel_to_snake("azure_blob").lower():
return AzureBlobStoreConnection
elif (cat == camel_to_snake(ConnectionCategory.CUSTOM_KEYS).lower()
or cat == camel_to_snake(WorkspaceConnectionTypes.CUSTOM).lower()):
# Accept both custom_keys and custom as conversion type in case of
# improper input.
return CustomConnection
# Custom accounts for both "custom_keys" and "custom" as conversion types in case of
# improper input.
# All values are lower-cased to normalize input.
# TODO replace azure blob with conn category when available.
CONNECTION_CATEGORY_TO_SUBCLASS_MAP = {
ConnectionCategory.AZURE_OPEN_AI.lower(): AzureOpenAIConnection,
ConnectionCategory.COGNITIVE_SEARCH.lower(): AzureAISearchConnection,
ConnectionCategory.COGNITIVE_SERVICE.lower(): AzureAIServiceConnection,
ConnectionCategory.GIT.lower(): GitHubConnection,
"azure_blob".lower(): AzureBlobStoreConnection,
ConnectionCategory.CUSTOM_KEYS.lower(): CustomConnection,
WorkspaceConnectionTypes.CUSTOM.lower(): CustomConnection
}
# snake and lower-case inputted type to match the map.
cat = _snake_to_camel(conn_type).lower()
if cat in CONNECTION_CATEGORY_TO_SUBCLASS_MAP:
return CONNECTION_CATEGORY_TO_SUBCLASS_MAP[cat]
return BaseConnection


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,16 +317,14 @@ def __init__(
super().__init__(target=target, type="git", credentials=credentials, **kwargs)

class AzureBlobStoreConnection(BaseConnection):
"""A Connection to an Azure Blob Datastore.
"""A Connection to an Azure Blob Datastore. Unlike other connection objects, this subclass has no credentials.
NOTE: This connection type is currently READ-ONLY via the LIST operation. Attempts to create or update
a connection of this type will result in an error.

:param name: Name of the connection.
:type name: str
:param target: The URL or ARM resource ID of the resource.
:type target: str
:param credentials: ?????
:type credentials: ?????
:param is_shared: For connections created for a project, this determines if the connection
is shared amongst other connections with that project's parent AI resource. Defaults to True.
:type is_shared: bool
Expand All @@ -338,7 +336,6 @@ def __init__(
self,
*,
target: str,
credentials = None,
container_name: str,
account_name: str,
**kwargs,
Expand All @@ -347,7 +344,6 @@ def __init__(
super().__init__(
target=target,
type="azure_blob",
credentials=credentials,
container_name=container_name,
account_name=account_name,
**kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23790,7 +23790,7 @@ def __init__(
self.auth_type = 'None' # type: str


class NoneDatastoreCredentials(DatastoreCredentials):
class NoneDatastoreNoneAuthTypeWorkspaceConnectionPropertiesentials(DatastoreCredentials):
"""Empty/none datastore credentials.

All required parameters must be populated in order to send to Azure.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

# pylint: disable=unused-argument

##### DEV NOTE: For some reason, these schemas correlate to the classes defined in ~azure.ai.ml.entities._credentials.
# There used to be a credentials.py file in ~azure.ai.ml.entities.workspace.connections, but it was, as far as I could tell,
# never used. So I removed it and added this comment.


from typing import Dict

from marshmallow import fields, post_load
Expand Down
3 changes: 2 additions & 1 deletion sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,8 @@ class InputOutputModes:


class WorkspaceConnectionTypes:
"""Names for workspace connection types that are different from that underlying api enum values."""
"""Names for workspace connection types that are different from that underlying api enum values
from the ConnectionCategory class."""

CUSTOM = "custom" # Corresponds to "custom_keys".

Expand Down
64 changes: 61 additions & 3 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# pylint: disable=protected-access,redefined-builtin

from abc import ABC
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, Type

from azure.ai.ml._azure_environments import _get_active_directory_url_from_metadata
from azure.ai.ml._restclient.v2022_01_01_preview.models import Identity as RestIdentityConfiguration
Expand Down Expand Up @@ -53,21 +53,54 @@
from azure.ai.ml._restclient.v2023_04_01_preview.models import (
WorkspaceConnectionAccessKey as RestWorkspaceConnectionAccessKey,
)
from azure.ai.ml._restclient.v2023_06_01_preview.models import ConnectionAuthType
from azure.ai.ml._restclient.v2023_06_01_preview.models import (
WorkspaceConnectionApiKey as RestWorkspaceConnectionApiKey,
)
from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml._utils.utils import camel_to_snake, snake_to_pascal
from azure.ai.ml._utils.utils import camel_to_snake, snake_to_pascal, _snake_to_camel
from azure.ai.ml.constants._common import CommonYamlFields, IdentityType
from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin, YamlTranslatableMixin
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, JobException, ValidationErrorType, ValidationException

from azure.ai.ml._restclient.v2023_08_01_preview.models import (
ConnectionAuthType,
AccessKeyAuthTypeWorkspaceConnectionProperties,
ApiKeyAuthWorkspaceConnectionProperties,
ManagedIdentityAuthTypeWorkspaceConnectionProperties,
NoneAuthTypeWorkspaceConnectionProperties,
PATAuthTypeWorkspaceConnectionProperties,
SASAuthTypeWorkspaceConnectionProperties,
ServicePrincipalAuthTypeWorkspaceConnectionProperties,
UsernamePasswordAuthTypeWorkspaceConnectionProperties,
)

class _BaseIdentityConfiguration(ABC, DictMixin, RestTranslatableMixin):
def __init__(self) -> None:
self.type: Any = None

@classmethod
def _get_credential_class_from_rest_type(cls, auth_type: str) -> Type:
# Defined in this file instead of in constants file to avoid adding imports there and risking
# circular imports. This map links rest enums to the corresponding client classes.
# Enums are all lower-cased because rest enums aren't always consistent with their
# camel casing rules.
# Defined in this class because I didn't want this at the bottom of the file,
# but the classes aren't visible to the interpreter at the start of the file.
# Technically most of these classes aren't child of _BaseIdentityConfiguration, but
# I don't care.
REST_CREDENTIAL_TYPE_TO_CLIENT_CLASS_MAP = {
ConnectionAuthType.SAS.lower(): SasTokenConfiguration,
ConnectionAuthType.PAT.lower(): PatTokenConfiguration,
ConnectionAuthType.ACCESS_KEY.lower(): AccessKeyConfiguration,
ConnectionAuthType.USERNAME_PASSWORD.lower(): UsernamePasswordConfiguration,
ConnectionAuthType.SERVICE_PRINCIPAL.lower(): ServicePrincipalConfiguration,
ConnectionAuthType.MANAGED_IDENTITY.lower(): ManagedIdentityConfiguration,
ConnectionAuthType.API_KEY.lower(): ApiKeyConfiguration,
}
if not auth_type:
return NoneCredentialConfiguration
return REST_CREDENTIAL_TYPE_TO_CLIENT_CLASS_MAP.get(_snake_to_camel(auth_type).lower(), NoneCredentialConfiguration)


class AccountKeyConfiguration(RestTranslatableMixin, DictMixin):
def __init__(
Expand Down Expand Up @@ -129,6 +162,10 @@ def __eq__(self, other: object) -> bool:

def __ne__(self, other: object) -> bool:
return not self.__eq__(other)

@classmethod
def _get_rest_properties_class(cls) -> Type:
return SASAuthTypeWorkspaceConnectionProperties


class PatTokenConfiguration(RestTranslatableMixin, DictMixin):
Expand Down Expand Up @@ -166,6 +203,9 @@ def __eq__(self, other: object) -> bool:
return NotImplemented
return self.pat == other.pat

@classmethod
def _get_rest_properties_class(cls) -> Type:
return PATAuthTypeWorkspaceConnectionProperties

class UsernamePasswordConfiguration(RestTranslatableMixin, DictMixin):
"""Username and password credentials.
Expand Down Expand Up @@ -204,6 +244,9 @@ def __eq__(self, other: object) -> bool:
return NotImplemented
return self.username == other.username and self.password == other.password

@classmethod
def _get_rest_properties_class(cls) -> Type:
return UsernamePasswordAuthTypeWorkspaceConnectionProperties

class BaseTenantCredentials(RestTranslatableMixin, DictMixin, ABC):
"""Base class for tenant credentials.
Expand Down Expand Up @@ -307,6 +350,9 @@ def __eq__(self, other: object) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)

@classmethod
def _get_rest_properties_class(cls) -> Type:
return ServicePrincipalAuthTypeWorkspaceConnectionProperties

class CertificateConfiguration(BaseTenantCredentials):
def __init__(
Expand Down Expand Up @@ -524,6 +570,9 @@ def __eq__(self, other: object) -> bool:
return NotImplemented
return self.client_id == other.client_id and self.resource_id == other.resource_id

@classmethod
def _get_rest_properties_class(cls) -> Type:
return ManagedIdentityAuthTypeWorkspaceConnectionProperties

class UserIdentityConfiguration(_BaseIdentityConfiguration):
"""User identity configuration.
Expand Down Expand Up @@ -763,6 +812,9 @@ def __eq__(self, other: object) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)

@classmethod
def _get_rest_properties_class(cls) -> Type:
return NoneAuthTypeWorkspaceConnectionProperties

class AccessKeyConfiguration(RestTranslatableMixin, DictMixin):
"""Access Key Credentials.
Expand Down Expand Up @@ -802,6 +854,9 @@ def __eq__(self, other: object) -> bool:
if not isinstance(other, AccessKeyConfiguration):
return NotImplemented
return self.access_key_id == other.access_key_id and self.secret_access_key == other.secret_access_key

def _get_rest_properties_class(self):
return AccessKeyAuthTypeWorkspaceConnectionProperties


@experimental
Expand Down Expand Up @@ -838,3 +893,6 @@ def __eq__(self, other: object) -> bool:
if not isinstance(other, ApiKeyConfiguration):
return NotImplemented
return bool(self.key == other.key)

def _get_rest_properties_class(self):
return ApiKeyAuthWorkspaceConnectionProperties
Loading