diff --git a/sdk/ai/azure-ai-resources/CHANGELOG.md b/sdk/ai/azure-ai-resources/CHANGELOG.md index b35f7349533c..5f786e114293 100644 --- a/sdk/ai/azure-ai-resources/CHANGELOG.md +++ b/sdk/ai/azure-ai-resources/CHANGELOG.md @@ -3,10 +3,13 @@ ## 1.0.0b7 (Unreleased) ### Features Added +- Connections LIST operation now supports returning data connections via new optional flag: include_data_connections. +- Connections have read-only support for 3 new connection types: gen 2, data lake, and azure blob. ### Breaking Changes ### Bugs Fixed +- Connections docstrings now discuss scope field. ### Other Changes diff --git a/sdk/ai/azure-ai-resources/azure/ai/resources/entities/__init__.py b/sdk/ai/azure-ai-resources/azure/ai/resources/entities/__init__.py index 87e168443003..15c8921fb9aa 100644 --- a/sdk/ai/azure-ai-resources/azure/ai/resources/entities/__init__.py +++ b/sdk/ai/azure-ai-resources/azure/ai/resources/entities/__init__.py @@ -13,11 +13,13 @@ AzureAIServiceConnection, GitHubConnection, CustomConnection, + AzureBlobStoreConnection, ) from .mlindex import Index from .project import Project from .data import Data from .configs import AzureOpenAIModelConfiguration +from azure.ai.ml.entities import ApiKeyConfiguration __all__ = [ "BaseConnection", @@ -31,5 +33,7 @@ "AzureOpenAIModelConfiguration", "GitHubConnection", "CustomConnection", + "AzureBlobStoreConnection", + "ApiKeyConfiguration" ] diff --git a/sdk/ai/azure-ai-resources/azure/ai/resources/entities/base_connection.py b/sdk/ai/azure-ai-resources/azure/ai/resources/entities/base_connection.py index 54c215ceda67..3f1c84120a4b 100644 --- a/sdk/ai/azure-ai-resources/azure/ai/resources/entities/base_connection.py +++ b/sdk/ai/azure-ai-resources/azure/ai/resources/entities/base_connection.py @@ -5,23 +5,26 @@ from typing import Any, Dict, Optional -from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml._utils.utils import camel_to_snake, _snake_to_camel from azure.ai.ml.entities import WorkspaceConnection from azure.ai.ml.entities._credentials import ApiKeyConfiguration from azure.core.credentials import TokenCredential -from azure.ai.ml._restclient.v2023_06_01_preview.models import ConnectionCategory - +from azure.ai.ml._restclient.v2023_08_01_preview.models import ConnectionCategory +from azure.ai.ml.constants._common import ( + WorkspaceConnectionTypes, +) class BaseConnection: - """A connection to a specific external AI system. This is a base class and should not be - instantiated directly. Use the child classes that are specialized for connections to different services. + """A connection to a another system or service. This is a base class and should only be directly + used to cover the simplest of connection types that don't merit their own subclass. :param name: The name of the connection :type name: str :param target: The URL or ARM resource ID of the external resource. :type target: str :param type: The type of connection this represents. Acceptable values include: - "azure_open_ai", "cognitive_service", and "cognitive_search". + "azure_open_ai", "cognitive_service", "cognitive_search", + "git", "azure_blob", and "custom". :type type: str :param credentials: The credentials for authenticating to the external resource. :type credentials: ~azure.ai.ml.entities.ApiKeyConfiguration @@ -46,10 +49,13 @@ def __init__( **kwargs, ): # Sneaky short-circuit to allow direct v2 WS conn injection without any - # polymorphic required argument hassles. + # polymorphic-argument hassles. + # See _from_v2_workspace_connection for usage. if kwargs.pop("make_empty", False): return + # Dev note: this is an important line, as is quietly changes the internal object from + # A WorkspaceConnection to the relevant subclass if needed. conn_class = WorkspaceConnection._get_entity_class_from_type(type) self._workspace_connection = conn_class( target=target, @@ -69,10 +75,23 @@ def _from_v2_workspace_connection(cls, workspace_connection: WorkspaceConnection :return: The converted connection. :rtype: ~azure.ai.resources.entities.Connection """ - # It's simpler to create a placeholder connection, then overwrite the internal WC. - # We don't need to worry about the potentially changing WC fields this way. conn_class = cls._get_ai_connection_class_from_type(workspace_connection.type) - conn = conn_class(type="a", target="a", credentials=None, name="a", make_empty=True, api_version=None, api_type=None, kind=None) + # This slightly-cheeky init is just a placeholder that is immediately replaced + # with a directly-injected v2 connection object. + # Since all connection class initializers have kwargs, we can just throw a kitchen sink's + # worth of inputs to satisfy all possible subclass input requirements. + conn = conn_class( + type="a", + target="a", + credentials=None, + name="a", + make_empty=True, + api_version=None, + api_type=None, + kind=None, + account_name="a", + container_name="a", + ) conn._workspace_connection = workspace_connection return conn @@ -81,21 +100,38 @@ def _get_ai_connection_class_from_type(cls, conn_type: str): """Given a connection type string, get the corresponding AI SDK object via class comparisons. Accounts for any potential camel/snake/capitalization issues. Returns the BaseConnection class if no match is found. + + :param conn_type: The type of the connection as a string. Should match one of the + supported connection category values. + :type conn_type: str """ #import here to avoid circular import from .connection_subtypes import ( AzureOpenAIConnection, AzureAISearchConnection, AzureAIServiceConnection, + GitHubConnection, + CustomConnection, + AzureBlobStoreConnection ) - cat = camel_to_snake(conn_type).lower() - if cat == camel_to_snake(ConnectionCategory.AZURE_OPEN_AI).lower(): - return AzureOpenAIConnection - elif cat == camel_to_snake(ConnectionCategory.COGNITIVE_SEARCH).lower(): - return AzureAISearchConnection - elif cat == camel_to_snake(ConnectionCategory.COGNITIVE_SERVICE).lower(): - return AzureAIServiceConnection + # Custom accounts for both "custom_keys" and "custom" as conversion types in case of + # improper input. + # All values are lower-cased to normalize input. + # TODO replace azure blob with conn category when available. + CONNECTION_CATEGORY_TO_SUBCLASS_MAP = { + ConnectionCategory.AZURE_OPEN_AI.lower(): AzureOpenAIConnection, + ConnectionCategory.COGNITIVE_SEARCH.lower(): AzureAISearchConnection, + ConnectionCategory.COGNITIVE_SERVICE.lower(): AzureAIServiceConnection, + ConnectionCategory.GIT.lower(): GitHubConnection, + "azureblob": AzureBlobStoreConnection, # TODO replace with conn category when available + ConnectionCategory.CUSTOM_KEYS.lower(): CustomConnection, + WorkspaceConnectionTypes.CUSTOM.lower(): CustomConnection + } + # snake and lower-case inputted type to match the map. + cat = _snake_to_camel(conn_type).lower() + if cat in CONNECTION_CATEGORY_TO_SUBCLASS_MAP: + return CONNECTION_CATEGORY_TO_SUBCLASS_MAP[cat] return BaseConnection diff --git a/sdk/ai/azure-ai-resources/azure/ai/resources/entities/connection_subtypes.py b/sdk/ai/azure-ai-resources/azure/ai/resources/entities/connection_subtypes.py index b98c1c23eb1f..4bcbfb45c346 100644 --- a/sdk/ai/azure-ai-resources/azure/ai/resources/entities/connection_subtypes.py +++ b/sdk/ai/azure-ai-resources/azure/ai/resources/entities/connection_subtypes.py @@ -12,6 +12,8 @@ CONNECTION_API_VERSION_KEY, CONNECTION_API_TYPE_KEY, CONNECTION_KIND_KEY, + CONNECTION_ACCOUNT_NAME_KEY, + CONNECTION_CONTAINER_NAME_KEY, ) from .base_connection import BaseConnection @@ -25,8 +27,6 @@ class AzureOpenAIConnection(BaseConnection): :type name: str :param target: The URL or ARM resource ID of the external resource. :type target: str - :param tags: Tag dictionary. Tags can be added, removed, and updated. - :type tags: dict :param credentials: The credentials for authenticating the external resource. :type credentials: ~azure.ai.ml.entities.ApiKeyConfiguration :param api_version: The api version that this connection was created for. @@ -36,6 +36,8 @@ class AzureOpenAIConnection(BaseConnection): :param is_shared: For connections created for a project, this determines if the connection is shared amongst other connections with that project's parent AI resource. Defaults to True. :type is_shared: bool + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict """ def __init__( @@ -66,8 +68,8 @@ def api_version(self) -> Optional[str]: def api_version(self, value: str) -> None: """Set the API version of the connection. - :return: the API version of the connection. - :rtype: str + :param value: the API version of the connection. + :type value: str """ self._workspace_connection.tags[CONNECTION_API_VERSION_KEY] = value @@ -86,8 +88,8 @@ def api_type(self) -> Optional[str]: def api_type(self, value: str) -> None: """Set the API type of the connection. - :return: the API type of the connection. - :rtype: str + :param value: the API type of the connection. + :type value: str """ self._workspace_connection.tags[CONNECTION_API_TYPE_KEY] = value @@ -154,8 +156,6 @@ class AzureAISearchConnection(BaseConnection): :type name: str :param target: The URL or ARM resource ID of the external resource. :type target: str - :param tags: Tag dictionary. Tags can be added, removed, and updated. - :type tags: dict :param credentials: The credentials for authenticating the external resource. :type credentials: ~azure.ai.ml.entities.ApiKeyConfiguration :param api_version: The api version that this connection was created for. Only applies to certain connection types. @@ -163,6 +163,8 @@ class AzureAISearchConnection(BaseConnection): :param is_shared: For connections created for a project, this determines if the connection is shared amongst other connections with that project's parent AI resource. Defaults to True. :type is_shared: bool + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict """ def __init__( @@ -192,8 +194,8 @@ def api_version(self) -> Optional[str]: def api_version(self, value: str) -> None: """Set the API version of the connection. - :return: the API version of the connection. - :rtype: str + :param value: the API version of the connection. + :type value: str """ self._workspace_connection.tags[CONNECTION_API_VERSION_KEY] = value @@ -224,8 +226,6 @@ class AzureAIServiceConnection(BaseConnection): :type name: str :param target: The URL or ARM resource ID of the external resource. :type target: str - :param tags: Tag dictionary. Tags can be added, removed, and updated. - :type tags: dict :param credentials: The credentials for authenticating the external resource. :type credentials: ~azure.ai.ml.entities.ApiKeyConfiguration :param api_version: The api version that this connection was created for. @@ -236,6 +236,8 @@ class AzureAIServiceConnection(BaseConnection): :param is_shared: For connections created for a project, this determines if the connection is shared amongst other connections with that project's parent AI resource. Defaults to True. :type is_shared: bool + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict """ def __init__( @@ -265,8 +267,8 @@ def api_version(self) -> Optional[str]: def api_version(self, value: str) -> None: """Set the API version of the connection. - :return: the API version of the connection. - :rtype: str + :param value: the API version of the connection. + :type value: str """ self._workspace_connection.tags[CONNECTION_API_VERSION_KEY] = value @@ -285,8 +287,8 @@ def kind(self) -> Optional[str]: def kind(self, value: str) -> None: """Set the kind of the connection. - :return: the kind of the connection. - :rtype: str + :param value: the kind of the connection. + :type value: str """ self._workspace_connection.tags[CONNECTION_KIND_KEY] = value @@ -297,13 +299,13 @@ class GitHubConnection(BaseConnection): :type name: str :param target: The URL or ARM resource ID of the external resource. :type target: str - :param tags: Tag dictionary. Tags can be added, removed, and updated. - :type tags: dict :param credentials: The credentials for authenticating the external resource. :type credentials: ~azure.ai.ml.entities.ApiKeyConfiguration :param is_shared: For connections created for a project, this determines if the connection is shared amongst other connections with that project's parent AI resource. Defaults to True. :type is_shared: bool + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict """ def __init__( @@ -316,6 +318,82 @@ def __init__( kwargs.pop("type", None) # make sure we never somehow use wrong type super().__init__(target=target, type="git", credentials=credentials, **kwargs) +class AzureBlobStoreConnection(BaseConnection): + """A Connection to an Azure Blob Datastore. Unlike other connection objects, this subclass has no credentials. + NOTE: This connection type is currently READ-ONLY via the LIST operation. Attempts to create or update + a connection of this type will result in an error. + + :param name: Name of the connection. + :type name: str + :param target: The URL or ARM resource ID of the resource. + :type target: str + :param container_name: The container name of the connection. + :type container_name: str + :param account_name: The account name of the connection. + :type account_name: str + :param is_shared: For connections created for a project, this determines if the connection + is shared amongst other connections with that project's parent AI resource. Defaults to True. + :type is_shared: bool + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict + """ + + def __init__( + self, + *, + target: str, + container_name: str, + account_name: str, + **kwargs, + ) -> None: + kwargs.pop("type", None) # make sure we never somehow use wrong type + super().__init__( + target=target, + type="azure_blob", + container_name=container_name, + account_name=account_name, + **kwargs + ) + + @property + def container_name(self) -> Optional[str]: + """The container name of the connection. + + :return: the container name of the connection. + :rtype: Optional[str] + """ + if self._workspace_connection.tags is not None: + return self._workspace_connection.tags.get(CONNECTION_CONTAINER_NAME_KEY, None) + return None + + @container_name.setter + def container_name(self, value: str) -> None: + """Set the container name of the connection. + + :param value: the new container name of the connection. + :type value: str + """ + self._workspace_connection.tags[CONNECTION_CONTAINER_NAME_KEY] = value + + @property + def account_name(self) -> Optional[str]: + """The account name of the connection. + + :return: the account name of the connection. + :rtype: Optional[str] + """ + if self._workspace_connection.tags is not None: + return self._workspace_connection.tags.get(CONNECTION_ACCOUNT_NAME_KEY, None) + return None + + @account_name.setter + def account_name(self, value: str) -> None: + """Set the account name of the connection. + + :param value: the new account name of the connection. + :type value: str + """ + self._workspace_connection.tags[CONNECTION_ACCOUNT_NAME_KEY] = value class CustomConnection(BaseConnection): """A Connection to system that's not encapsulated by other connection types. @@ -323,13 +401,13 @@ class CustomConnection(BaseConnection): :type name: str :param target: The URL or ARM resource ID of the external resource. :type target: str - :param tags: Tag dictionary. Tags can be added, removed, and updated. - :type tags: dict :param credentials: The credentials for authenticating the external resource. :type credentials: ~azure.ai.ml.entities.ApiKeyConfiguration :param is_shared: For connections created for a project, this determines if the connection is shared amongst other connections with that project's parent AI resource. Defaults to True. :type is_shared: bool + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict """ def __init__( self, diff --git a/sdk/ai/azure-ai-resources/azure/ai/resources/operations/_connection_operations.py b/sdk/ai/azure-ai-resources/azure/ai/resources/operations/_connection_operations.py index 1d8853cef932..59416edc46e0 100644 --- a/sdk/ai/azure-ai-resources/azure/ai/resources/operations/_connection_operations.py +++ b/sdk/ai/azure-ai-resources/azure/ai/resources/operations/_connection_operations.py @@ -30,20 +30,34 @@ def __init__(self, *, resource_ml_client: MLClient = None, project_ml_client: ML @distributed_trace @monitor_with_activity(logger, "Connection.List", ActivityType.PUBLICAPI) - def list(self, connection_type: Optional[str] = None, scope: str = OperationScope.AI_RESOURCE) -> Iterable[BaseConnection]: + def list( + self, + connection_type: Optional[str] = None, + scope: str = OperationScope.AI_RESOURCE, + include_data_connections: bool = False, + ) -> Iterable[BaseConnection]: """List all connection assets in a project. :param connection_type: If set, return only connections of the specified type. :type connection_type: str + :param scope: The scope of the operation, which determines if the operation will list all connections + that are available to the AI Client's AI Resource or just those available to the project. + Defaults to AI resource-level scoping. + :type scope: ~azure.ai.resources.constants.OperationScope + :param include_data_connections: If true, also return data connections. Defaults to False. + :type include_data_connections: bool - :return: An iterator like instance of connection objects + :return: An iterator of connection objects :rtype: Iterable[Connection] """ client = self._resource_ml_client if scope == OperationScope.AI_RESOURCE else self._project_ml_client - return [ - BaseConnection._from_v2_workspace_connection(conn) - for conn in client._workspace_connections.list(connection_type=connection_type) - ] + + operation_result = client._workspace_connections.list( + connection_type=connection_type, + include_data_connections=include_data_connections + ) + + return [BaseConnection._from_v2_workspace_connection(conn) for conn in operation_result] @distributed_trace @monitor_with_activity(logger, "Connection.Get", ActivityType.PUBLICAPI) @@ -52,6 +66,10 @@ def get(self, name: str, scope: str = OperationScope.AI_RESOURCE, **kwargs) -> B :param name: Name of the connection. :type name: str + :param scope: The scope of the operation, which determines if the operation will search among + all connections that are available to the AI Client's AI Resource or just those available to the project. + Defaults to AI resource-level scoping. + :type scope: ~azure.ai.resources.constants.OperationScope :return: The connection with the provided name. :rtype: Connection @@ -62,7 +80,7 @@ def get(self, name: str, scope: str = OperationScope.AI_RESOURCE, **kwargs) -> B # It's by design that both API and V2 SDK don't include the secrets from API response, the following # code fills the gap when possible - if not connection.credentials.key: + if not (connection.credentials and connection.credentials.key): list_secrets_response = client.connections._operation.list_secrets( connection_name=name, resource_group_name=client.resource_group_name, @@ -100,6 +118,10 @@ def delete(self, name: str, scope: str = OperationScope.AI_RESOURCE) -> None: :param name: Name of the connection to delete. :type name: str + :param scope: The scope of the operation, which determines if the operation should search amongst + the connections available to the AI Client's AI Resource for the target connection, or through + the connections available to the project. Defaults to AI resource-level scoping. + :type scope: ~azure.ai.resources.constants.OperationScope """ client = self._resource_ml_client if scope == OperationScope.AI_RESOURCE else self._project_ml_client return client._workspace_connections.delete(name=name) diff --git a/sdk/ai/azure-ai-resources/setup.py b/sdk/ai/azure-ai-resources/setup.py index ef641d1b84d5..5d4bcf3261fb 100644 --- a/sdk/ai/azure-ai-resources/setup.py +++ b/sdk/ai/azure-ai-resources/setup.py @@ -67,7 +67,7 @@ python_requires="<4.0,>=3.8", install_requires=[ # NOTE: To avoid breaking changes in a major version bump, all dependencies should pin an upper bound if possible. - "azure-ai-ml>=1.13.0", + "azure-ai-ml>=1.14.0", "mlflow-skinny<3", "opencensus-ext-logging<=0.1.1", "azure-mgmt-resource<23.0.0,>=22.0.0", diff --git a/sdk/ml/azure-ai-ml/CHANGELOG.md b/sdk/ml/azure-ai-ml/CHANGELOG.md index b6798f3a4c38..062460aed6bc 100644 --- a/sdk/ml/azure-ai-ml/CHANGELOG.md +++ b/sdk/ml/azure-ai-ml/CHANGELOG.md @@ -4,6 +4,9 @@ ### Features Added - Remove `experimental` tag for `ml_client.jobs.validate`. +- Workspace Connection has new read-only subclass: AzureBlobStoreWorkspaceConnectionSchema. +- Workspace Connection supports 2 new types under main class: gen 2 and azure_one_lake. +- Workspace Connection LIST operation can return data connections via new optional flag: include_data_connections. ### Bugs Fixed - Fix pipeline job `outputs` not load correctly when `component: ` exists in pipeline job yaml. diff --git a/sdk/ml/azure-ai-ml/assets.json b/sdk/ml/azure-ai-ml/assets.json index bfbfa8146344..e62f66b528d9 100644 --- a/sdk/ml/azure-ai-ml/assets.json +++ b/sdk/ml/azure-ai-ml/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "python", "TagPrefix": "python/ml/azure-ai-ml", - "Tag": "python/ml/azure-ai-ml_dc05917b7a" + "Tag": "python/ml/azure-ai-ml_67cc2fc9de" } diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/connections/__init__.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/connections/__init__.py index 95df393e3b52..b61f012d0414 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/connections/__init__.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/connections/__init__.py @@ -9,6 +9,7 @@ OpenAIWorkspaceConnectionSchema, AzureAISearchWorkspaceConnectionSchema, AzureAIServiceWorkspaceConnectionSchema, + AzureBlobStoreWorkspaceConnectionSchema, ) __all__ = [ @@ -16,4 +17,5 @@ "OpenAIWorkspaceConnectionSchema", "AzureAISearchWorkspaceConnectionSchema", "AzureAIServiceWorkspaceConnectionSchema", + "AzureBlobStoreWorkspaceConnectionSchema", ] diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/connections/credentials.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/connections/credentials.py index 39591a96d602..e343b0ca8b72 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/connections/credentials.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/connections/credentials.py @@ -4,6 +4,10 @@ # pylint: disable=unused-argument +##### DEV NOTE: For some reason, these schemas correlate to the classes defined in ~azure.ai.ml.entities._credentials. +# There used to be a credentials.py file in ~azure.ai.ml.entities.workspace.connections, +# but it was, as far as I could tell, never used. So I removed it and added this comment. + from typing import Dict from marshmallow import fields, post_load diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/connections/workspace_connection_subtypes.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/connections/workspace_connection_subtypes.py index 162d026620e5..5d844077f966 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/connections/workspace_connection_subtypes.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/connections/workspace_connection_subtypes.py @@ -6,7 +6,7 @@ from marshmallow import fields, post_load -from azure.ai.ml._restclient.v2023_06_01_preview.models import ConnectionCategory +from azure.ai.ml._restclient.v2024_01_01_preview.models import ConnectionCategory from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum from azure.ai.ml._schema.workspace.connections.credentials import ApiKeyConfigurationSchema from azure.ai.ml._utils.utils import camel_to_snake @@ -62,3 +62,21 @@ def make(self, data, **kwargs): from azure.ai.ml.entities import AzureAIServiceWorkspaceConnection return AzureAIServiceWorkspaceConnection(**data) + + +# pylint: disable-next=name-too-long +class AzureBlobStoreWorkspaceConnectionSchema(WorkspaceConnectionSchema): + # type and credentials limited + type = StringTransformedEnum( # TODO replace with real connection category. + allowed_values=ConnectionCategory.AZURE_BLOB, casing_transform=camel_to_snake, required=True + ) + credentials = NestedField(ApiKeyConfigurationSchema) + + account_name = fields.Str(required=True, allow_none=False) + container_name = fields.Str(required=True, allow_none=False) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import AzureBlobStoreWorkspaceConnection + + return AzureBlobStoreWorkspaceConnection(**data) diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_version.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_version.py index de751c672e6e..c14645fdfb4a 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_version.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_version.py @@ -2,4 +2,4 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -VERSION = "1.13.0" +VERSION = "1.14.0" diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py b/sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py index 9045ea33fe8d..c8a785b35cd9 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py @@ -186,6 +186,8 @@ CONNECTION_API_VERSION_KEY = "ApiVersion" CONNECTION_API_TYPE_KEY = "ApiType" CONNECTION_KIND_KEY = "Kind" +CONNECTION_CONTAINER_NAME_KEY = "ContainerName" +CONNECTION_ACCOUNT_NAME_KEY = "AccountName" class DefaultOpenEncoding: @@ -791,9 +793,10 @@ class InputOutputModes: class WorkspaceConnectionTypes: - """Names for workspace connection types that are different from that underlying api enum values.""" + """Names for workspace connection types that are different from that underlying api enum values + from the ConnectionCategory class.""" - CUSTOM = "custom" + CUSTOM = "custom" # Corresponds to "custom_keys". class LegacyAssetTypes: diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py index 59b5b05f37d9..b0bcfc0881ac 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py @@ -205,6 +205,7 @@ AzureAISearchWorkspaceConnection, AzureAIServiceWorkspaceConnection, AzureOpenAIWorkspaceConnection, + AzureBlobStoreWorkspaceConnection, ) from ._workspace.customer_managed_key import CustomerManagedKey from ._workspace.diagnose import ( @@ -285,6 +286,7 @@ "WorkspaceKeys", "WorkspaceConnection", "AzureOpenAIWorkspaceConnection", + "AzureBlobStoreWorkspaceConnection", "AzureAISearchWorkspaceConnection", "AzureAIServiceWorkspaceConnection", "DiagnoseRequestProperties", diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_credentials.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_credentials.py index 8b42547ea8ac..54a8a0ff53d5 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_credentials.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_credentials.py @@ -5,7 +5,7 @@ # pylint: disable=protected-access,redefined-builtin from abc import ABC -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, Type from azure.ai.ml._azure_environments import _get_active_directory_url_from_metadata from azure.ai.ml._restclient.v2022_01_01_preview.models import Identity as RestIdentityConfiguration @@ -53,21 +53,57 @@ from azure.ai.ml._restclient.v2023_04_01_preview.models import ( WorkspaceConnectionAccessKey as RestWorkspaceConnectionAccessKey, ) -from azure.ai.ml._restclient.v2023_06_01_preview.models import ConnectionAuthType from azure.ai.ml._restclient.v2023_06_01_preview.models import ( WorkspaceConnectionApiKey as RestWorkspaceConnectionApiKey, ) from azure.ai.ml._utils._experimental import experimental -from azure.ai.ml._utils.utils import camel_to_snake, snake_to_pascal +from azure.ai.ml._utils.utils import camel_to_snake, snake_to_pascal, _snake_to_camel from azure.ai.ml.constants._common import CommonYamlFields, IdentityType from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin, YamlTranslatableMixin from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, JobException, ValidationErrorType, ValidationException +from azure.ai.ml._restclient.v2023_08_01_preview.models import ( + ConnectionAuthType, + AccessKeyAuthTypeWorkspaceConnectionProperties, + ApiKeyAuthWorkspaceConnectionProperties, + ManagedIdentityAuthTypeWorkspaceConnectionProperties, + NoneAuthTypeWorkspaceConnectionProperties, + PATAuthTypeWorkspaceConnectionProperties, + SASAuthTypeWorkspaceConnectionProperties, + ServicePrincipalAuthTypeWorkspaceConnectionProperties, + UsernamePasswordAuthTypeWorkspaceConnectionProperties, +) + class _BaseIdentityConfiguration(ABC, DictMixin, RestTranslatableMixin): def __init__(self) -> None: self.type: Any = None + @classmethod + def _get_credential_class_from_rest_type(cls, auth_type: str) -> Type: + # Defined in this file instead of in constants file to avoid adding imports there and risking + # circular imports. This map links rest enums to the corresponding client classes. + # Enums are all lower-cased because rest enums aren't always consistent with their + # camel casing rules. + # Defined in this class because I didn't want this at the bottom of the file, + # but the classes aren't visible to the interpreter at the start of the file. + # Technically most of these classes aren't child of _BaseIdentityConfiguration, but + # I don't care. + REST_CREDENTIAL_TYPE_TO_CLIENT_CLASS_MAP = { + ConnectionAuthType.SAS.lower(): SasTokenConfiguration, + ConnectionAuthType.PAT.lower(): PatTokenConfiguration, + ConnectionAuthType.ACCESS_KEY.lower(): AccessKeyConfiguration, + ConnectionAuthType.USERNAME_PASSWORD.lower(): UsernamePasswordConfiguration, + ConnectionAuthType.SERVICE_PRINCIPAL.lower(): ServicePrincipalConfiguration, + ConnectionAuthType.MANAGED_IDENTITY.lower(): ManagedIdentityConfiguration, + ConnectionAuthType.API_KEY.lower(): ApiKeyConfiguration, + } + if not auth_type: + return NoneCredentialConfiguration + return REST_CREDENTIAL_TYPE_TO_CLIENT_CLASS_MAP.get( + _snake_to_camel(auth_type).lower(), NoneCredentialConfiguration + ) + class AccountKeyConfiguration(RestTranslatableMixin, DictMixin): def __init__( @@ -130,6 +166,10 @@ def __eq__(self, other: object) -> bool: def __ne__(self, other: object) -> bool: return not self.__eq__(other) + @classmethod + def _get_rest_properties_class(cls) -> Type: + return SASAuthTypeWorkspaceConnectionProperties + class PatTokenConfiguration(RestTranslatableMixin, DictMixin): """Personal access token credentials. @@ -166,6 +206,10 @@ def __eq__(self, other: object) -> bool: return NotImplemented return self.pat == other.pat + @classmethod + def _get_rest_properties_class(cls) -> Type: + return PATAuthTypeWorkspaceConnectionProperties + class UsernamePasswordConfiguration(RestTranslatableMixin, DictMixin): """Username and password credentials. @@ -204,6 +248,10 @@ def __eq__(self, other: object) -> bool: return NotImplemented return self.username == other.username and self.password == other.password + @classmethod + def _get_rest_properties_class(cls) -> Type: + return UsernamePasswordAuthTypeWorkspaceConnectionProperties + class BaseTenantCredentials(RestTranslatableMixin, DictMixin, ABC): """Base class for tenant credentials. @@ -307,6 +355,10 @@ def __eq__(self, other: object) -> bool: def __ne__(self, other: object) -> bool: return not self.__eq__(other) + @classmethod + def _get_rest_properties_class(cls) -> Type: + return ServicePrincipalAuthTypeWorkspaceConnectionProperties + class CertificateConfiguration(BaseTenantCredentials): def __init__( @@ -524,6 +576,10 @@ def __eq__(self, other: object) -> bool: return NotImplemented return self.client_id == other.client_id and self.resource_id == other.resource_id + @classmethod + def _get_rest_properties_class(cls) -> Type: + return ManagedIdentityAuthTypeWorkspaceConnectionProperties + class UserIdentityConfiguration(_BaseIdentityConfiguration): """User identity configuration. @@ -763,6 +819,10 @@ def __eq__(self, other: object) -> bool: def __ne__(self, other: object) -> bool: return not self.__eq__(other) + @classmethod + def _get_rest_properties_class(cls) -> Type: + return NoneAuthTypeWorkspaceConnectionProperties + class AccessKeyConfiguration(RestTranslatableMixin, DictMixin): """Access Key Credentials. @@ -803,6 +863,9 @@ def __eq__(self, other: object) -> bool: return NotImplemented return self.access_key_id == other.access_key_id and self.secret_access_key == other.secret_access_key + def _get_rest_properties_class(self): + return AccessKeyAuthTypeWorkspaceConnectionProperties + @experimental class ApiKeyConfiguration(RestTranslatableMixin, DictMixin): @@ -838,3 +901,6 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, ApiKeyConfiguration): return NotImplemented return bool(self.key == other.key) + + def _get_rest_properties_class(self): + return ApiKeyAuthWorkspaceConnectionProperties diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/connections/credentials.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/connections/credentials.py deleted file mode 100644 index 478159c655ab..000000000000 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/connections/credentials.py +++ /dev/null @@ -1,197 +0,0 @@ -# --------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# --------------------------------------------------------- - -from azure.ai.ml._restclient.v2022_01_01_preview.models import ( - ConnectionAuthType, - ManagedIdentity, - ServicePrincipal, - SharedAccessSignature, - UsernamePassword, -) -from azure.ai.ml.entities._mixins import RestTranslatableMixin - - -class WorkspaceConnectionCredentials(RestTranslatableMixin): - """Base class for workspace connection credentials.""" - - def __init__(self) -> None: - self.type = None - - -class SasTokenCredentials(WorkspaceConnectionCredentials): - """Shared Access Signatures Token Credentials. - - :param sas: sas token - :type sas: str - """ - - def __init__(self, *, sas: str): - super().__init__() - self.type = ConnectionAuthType.SAS - self.sas = sas - - def _to_rest_object(self) -> SharedAccessSignature: - return SharedAccessSignature(sas=self.sas) - - @classmethod - def _from_rest_object(cls, obj: SharedAccessSignature) -> "SasTokenCredentials": - return cls(sas=obj.sas if obj.sas else None) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SasTokenCredentials): - return NotImplemented - return self.sas == other.sas - - -class UsernamePasswordCredentials(WorkspaceConnectionCredentials): - """Username Password Credentials. - - :param username: username - :type username: str - :param password: password - :type password: str - """ - - def __init__( - self, - *, - username: str, - password: str, - ): - super().__init__() - self.type = ConnectionAuthType.USERNAME_PASSWORD - self.username = username - self.password = password - - def _to_rest_object(self) -> UsernamePassword: - return UsernamePassword(username=self.username, password=self.password) - - @classmethod - def _from_rest_object(cls, obj: UsernamePassword) -> "UsernamePasswordCredentials": - return cls( - username=obj.username if obj.username else None, - password=obj.password if obj.password else None, - ) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, UsernamePasswordCredentials): - return NotImplemented - return self.username == other.username and self.password == other.password - - -class ManagedIdentityCredentials(WorkspaceConnectionCredentials): - """Managed Identity Credentials. - - :param client_id: client id, should be guid - :type client_id: str - :param resource_id: resource id - :type resource_id: str - """ - - def __init__(self, *, client_id: str, resource_id: str): - super().__init__() - self.type = ConnectionAuthType.MANAGED_IDENTITY - self.client_id = client_id - # TODO: Check if both client_id and resource_id are required - self.resource_id = resource_id - - def _to_rest_object(self) -> ManagedIdentity: - return ManagedIdentity(client_id=self.client_id, resource_id=self.resource_id) - - @classmethod - def _from_rest_object(cls, obj: ManagedIdentity) -> "ManagedIdentityCredentials": - return cls( - client_id=obj.client_id, - resource_id=obj.resource_id, - ) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ManagedIdentityCredentials): - return NotImplemented - return self.client_id == other.client_id and self.resource_id == other.resource_id - - -class ServicePrincipalCredentials(WorkspaceConnectionCredentials): - """Service Principal Credentials. - - :param client_id: client id, should be guid - :type client_id: str - :param client_secret: client secret - :type client_secret: str - :param tenant_id: tenant id, should be guid - :type tenant_id: str - """ - - def __init__( - self, - *, - client_id: str, - client_secret: str, - tenant_id: str, - ): - super().__init__() - self.type = ConnectionAuthType.SERVICE_PRINCIPAL - self.client_id = client_id - self.client_secret = client_secret - self.tenant_id = tenant_id - - def _to_rest_object(self) -> ServicePrincipal: - return ServicePrincipal( - client_id=self.client_id, - client_secret=self.client_secret, - tenant_id=self.tenant_id, - ) - - @classmethod - def _from_rest_object(cls, obj: ServicePrincipal) -> "ServicePrincipalCredentials": - return cls( - client_id=obj.client_id if obj.client_id else None, - client_secret=obj.client_secret if obj.client_secret else None, - tenant_id=obj.tenant_id if obj.tenant_id else None, - ) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ServicePrincipalCredentials): - return NotImplemented - return ( - self.client_id == other.client_id - and self.client_secret == other.client_secret - and self.tenant_id == other.tenant_id - ) - - -class AccessKeyCredentials(WorkspaceConnectionCredentials): - """Access Key Credentials. - - :param access_key_id: access key id - :type access_key_id: str - :param secret_access_key: secret access key - :type secret_access_key: str - """ - - def __init__( - self, - *, - access_key_id: str, - secret_access_key: str, - ): - super().__init__() - self.type = ConnectionAuthType.ACCESS_KEY - self.access_key_id = access_key_id - self.secret_access_key = secret_access_key - - def _to_rest_object(self) -> UsernamePassword: - return UsernamePassword(username=self.access_key_id, password=self.secret_access_key) - - @classmethod - def _from_rest_object(cls, obj: UsernamePassword) -> "AccessKeyCredentials": - return cls( - access_key_id=obj.username if obj.username else None, - secret_access_key=obj.password if obj.password else None, - ) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, AccessKeyCredentials): - return NotImplemented - return self.access_key_id == other.access_key_id and self.secret_access_key == other.secret_access_key diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/connections/workspace_connection.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/connections/workspace_connection.py index aaa8b833f78c..5469f85d5acb 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/connections/workspace_connection.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/connections/workspace_connection.py @@ -6,36 +6,21 @@ from os import PathLike from pathlib import Path +import warnings from typing import IO, Any, AnyStr, Dict, List, Optional, Type, Union, cast -from azure.ai.ml._restclient.v2023_08_01_preview.models import ( - AccessKeyAuthTypeWorkspaceConnectionProperties, - ApiKeyAuthWorkspaceConnectionProperties, - ConnectionAuthType, +from azure.ai.ml._restclient.v2024_01_01_preview.models import ( ConnectionCategory, - ManagedIdentityAuthTypeWorkspaceConnectionProperties, NoneAuthTypeWorkspaceConnectionProperties, - PATAuthTypeWorkspaceConnectionProperties, - SASAuthTypeWorkspaceConnectionProperties, - ServicePrincipalAuthTypeWorkspaceConnectionProperties, - UsernamePasswordAuthTypeWorkspaceConnectionProperties, ) from azure.ai.ml._restclient.v2023_08_01_preview.models import ( WorkspaceConnectionPropertiesV2BasicResource as RestWorkspaceConnection, ) from azure.ai.ml._schema.workspace.connections.workspace_connection import WorkspaceConnectionSchema -from azure.ai.ml._schema.workspace.connections.workspace_connection_subtypes import ( - AzureAISearchWorkspaceConnectionSchema, - AzureAIServiceWorkspaceConnectionSchema, - OpenAIWorkspaceConnectionSchema, -) from azure.ai.ml._utils._experimental import experimental from azure.ai.ml._utils.utils import _snake_to_camel, camel_to_snake, dump_yaml_to_file from azure.ai.ml.constants._common import ( BASE_PATH_CONTEXT_KEY, - CONNECTION_API_TYPE_KEY, - CONNECTION_API_VERSION_KEY, - CONNECTION_KIND_KEY, PARAMS_OVERRIDE_KEY, WorkspaceConnectionTypes, ) @@ -47,6 +32,8 @@ SasTokenConfiguration, ServicePrincipalConfiguration, UsernamePasswordConfiguration, + NoneCredentialConfiguration, + _BaseIdentityConfiguration, ) from azure.ai.ml.entities._resource import Resource from azure.ai.ml.entities._system_data import SystemData @@ -73,11 +60,10 @@ class WorkspaceConnection(Resource): :param type: The category of external resource for this connection. :type type: The type of workspace connection, possible values are: "git", "python_feed", "container_registry", "feature_store", "s3", "snowflake", "azure_sql_db", "azure_synapse_analytics", "azure_my_sql_db", - "azure_postgres_db", "custom". + "azure_postgres_db", "adls_gen_2", "azure_one_lake", "custom". :param credentials: The credentials for authenticating to the external resource. Note that certain connection types (as defined by the type input) only accept certain types of credentials. :type credentials: Union[ - ~azure.ai.ml.entities.PatTokenConfiguration, ~azure.ai.ml.entities.SasTokenConfiguration, ~azure.ai.ml.entities.UsernamePasswordConfiguration, @@ -109,6 +95,26 @@ def __init__( is_shared: bool = True, **kwargs: Any, ): + + # Dev note: This initializer has an undocumented kwarg "from_child" to determine if this initialization + # is from a child class. + # This kwarg is required to allow instantiation of types that are associated with subtypes without a + # warning printout. + # The additional undocumented kwarg "strict_typing" turns the warning into a value error. + from_child = kwargs.pop("from_child", False) + strict_typing = kwargs.pop("strict_typing", False) + correct_class = WorkspaceConnection._get_entity_class_from_type(type) + if not from_child and correct_class != WorkspaceConnection: + if strict_typing: + raise ValueError( + f"Cannot instantiate a base WorkspaceConnection with a type of {type}. " + f"Please use the appropriate subclass {correct_class.__name__} instead." + ) + warnings.warn( + f"The workspace connection of {type} has additional fields and should not be instantiated directly " + f"from the WorkspaceConnection class. Please use its subclass {correct_class.__name__} instead.", + ) + super().__init__(**kwargs) self.type = type @@ -168,7 +174,6 @@ def credentials( ~azure.ai.ml.entities.ServicePrincipalConfiguration, ~azure.ai.ml.entities.AccessKeyConfiguration, ~azure.ai.ml.entities.ApiKeyConfiguration - ] """ return self._credentials @@ -262,29 +267,17 @@ def _to_dict(self) -> Dict: @classmethod def _from_rest_object(cls, rest_obj: RestWorkspaceConnection) -> Optional["WorkspaceConnection"]: - from .workspace_connection_subtypes import ( - AzureAISearchWorkspaceConnection, - AzureAIServiceWorkspaceConnection, - AzureOpenAIWorkspaceConnection, - ) - if not rest_obj: return None conn_cat = rest_obj.properties.category conn_class = cls._get_entity_class_from_type(conn_cat) - popped_tags = [] - if conn_class == AzureOpenAIWorkspaceConnection: - popped_tags = [CONNECTION_API_VERSION_KEY, CONNECTION_API_TYPE_KEY] - elif conn_class == AzureAISearchWorkspaceConnection: - popped_tags = [CONNECTION_API_VERSION_KEY] - elif conn_class == AzureAIServiceWorkspaceConnection: - popped_tags = [CONNECTION_API_VERSION_KEY, CONNECTION_KIND_KEY] + popped_tags = conn_class._get_required_metadata_fields() rest_kwargs = cls._extract_kwargs_from_rest_obj(rest_obj=rest_obj, popped_tags=popped_tags) - # Renaming for client clarity - if rest_kwargs["type"] == camel_to_snake(ConnectionCategory.CUSTOM_KEYS): + # Check for alternative name for custom connection type (added for client clarity). + if rest_kwargs["type"].lower() == camel_to_snake(ConnectionCategory.CUSTOM_KEYS).lower(): rest_kwargs["type"] = WorkspaceConnectionTypes.CUSTOM workspace_connection = conn_class(**rest_kwargs) return cast(Optional["WorkspaceConnection"], workspace_connection) @@ -293,38 +286,30 @@ def _validate(self) -> str: return str(self.name) def _to_rest_object(self) -> RestWorkspaceConnection: - workspace_connection_properties_class: Any = None - auth_type = self.credentials.type if self._credentials else None - - if auth_type == camel_to_snake(ConnectionAuthType.PAT): - workspace_connection_properties_class = PATAuthTypeWorkspaceConnectionProperties - elif auth_type == camel_to_snake(ConnectionAuthType.MANAGED_IDENTITY): - workspace_connection_properties_class = ManagedIdentityAuthTypeWorkspaceConnectionProperties - elif auth_type == camel_to_snake(ConnectionAuthType.USERNAME_PASSWORD): - workspace_connection_properties_class = UsernamePasswordAuthTypeWorkspaceConnectionProperties - elif auth_type == camel_to_snake(ConnectionAuthType.ACCESS_KEY): - workspace_connection_properties_class = AccessKeyAuthTypeWorkspaceConnectionProperties - elif auth_type == camel_to_snake(ConnectionAuthType.SAS): - workspace_connection_properties_class = SASAuthTypeWorkspaceConnectionProperties - elif auth_type == camel_to_snake(ConnectionAuthType.SERVICE_PRINCIPAL): - workspace_connection_properties_class = ServicePrincipalAuthTypeWorkspaceConnectionProperties - elif auth_type == camel_to_snake(ConnectionAuthType.API_KEY): - workspace_connection_properties_class = ApiKeyAuthWorkspaceConnectionProperties - elif auth_type is None: - workspace_connection_properties_class = NoneAuthTypeWorkspaceConnectionProperties - - # Convert from human readable to api enums if needed. + workspace_connection_properties_class: Any = NoneAuthTypeWorkspaceConnectionProperties + if self._credentials: + workspace_connection_properties_class = self._credentials._get_rest_properties_class() + # Convert from human readable type to corresponding api enum if needed. conn_type = self.type if conn_type == WorkspaceConnectionTypes.CUSTOM: conn_type = ConnectionCategory.CUSTOM_KEYS - properties = workspace_connection_properties_class( - target=self.target, - credentials=self.credentials._to_workspace_connection_rest_object(), - metadata=self.tags, - category=_snake_to_camel(conn_type), - is_shared_to_all=self.is_shared, - ) + # No credential property bag uniquely has different inputs from ALL other property bag classes. + if workspace_connection_properties_class == NoneAuthTypeWorkspaceConnectionProperties: + properties = workspace_connection_properties_class( + target=self.target, + metadata=self.tags, + category=_snake_to_camel(conn_type), + is_shared_to_all=self.is_shared, + ) + else: + properties = workspace_connection_properties_class( + target=self.target, + credentials=self.credentials._to_workspace_connection_rest_object() if self._credentials else None, + metadata=self.tags, + category=_snake_to_camel(conn_type), + is_shared_to_all=self.is_shared, + ) return RestWorkspaceConnection(properties=properties) @@ -346,20 +331,10 @@ def _extract_kwargs_from_rest_obj(cls, rest_obj: RestWorkspaceConnection, popped """ properties = rest_obj.properties credentials: Any = None - if properties.auth_type == ConnectionAuthType.PAT: - credentials = PatTokenConfiguration._from_workspace_connection_rest_object(properties.credentials) - elif properties.auth_type == ConnectionAuthType.SAS: - credentials = SasTokenConfiguration._from_workspace_connection_rest_object(properties.credentials) - elif properties.auth_type == ConnectionAuthType.MANAGED_IDENTITY: - credentials = ManagedIdentityConfiguration._from_workspace_connection_rest_object(properties.credentials) - elif properties.auth_type == ConnectionAuthType.USERNAME_PASSWORD: - credentials = UsernamePasswordConfiguration._from_workspace_connection_rest_object(properties.credentials) - elif properties.auth_type == ConnectionAuthType.ACCESS_KEY: - credentials = AccessKeyConfiguration._from_workspace_connection_rest_object(properties.credentials) - elif properties.auth_type == ConnectionAuthType.SERVICE_PRINCIPAL: - credentials = ServicePrincipalConfiguration._from_workspace_connection_rest_object(properties.credentials) - elif properties.auth_type == ConnectionAuthType.API_KEY: - credentials = ApiKeyConfiguration._from_workspace_connection_rest_object(properties.credentials) + + credentials_class = _BaseIdentityConfiguration._get_credential_class_from_rest_type(properties.auth_type) + if credentials_class is not NoneCredentialConfiguration: + credentials = credentials_class._from_workspace_connection_rest_object(properties.credentials) tags = properties.metadata if hasattr(properties, "metadata") else None rest_kwargs = { @@ -384,7 +359,9 @@ def _get_entity_class_from_type(cls, conn_type: Optional[str]) -> Type: workspace connection class or subclass. Accounts for potential snake/camel case and capitalization differences. - :param conn_type: The connection type. + :param conn_type: The desired connection type represented as a string. This + should align with the 'connection_category' enum field defined in the rest client, but + can be in snake or camel case. :type conn_type: str :return: The workspace connection class the conn_type corresponds to. @@ -392,22 +369,24 @@ def _get_entity_class_from_type(cls, conn_type: Optional[str]) -> Type: """ if conn_type is None: return WorkspaceConnection - # done here to avoid circular imports on load + # Imports are done here to avoid circular imports on load. from .workspace_connection_subtypes import ( AzureAISearchWorkspaceConnection, AzureAIServiceWorkspaceConnection, AzureOpenAIWorkspaceConnection, + AzureBlobStoreWorkspaceConnection, ) - cat = camel_to_snake(conn_type).lower() - conn_class: Type = WorkspaceConnection - if cat == camel_to_snake(ConnectionCategory.AZURE_OPEN_AI).lower(): - conn_class = AzureOpenAIWorkspaceConnection - elif cat == camel_to_snake(ConnectionCategory.COGNITIVE_SEARCH).lower(): - conn_class = AzureAISearchWorkspaceConnection - elif cat == camel_to_snake(ConnectionCategory.COGNITIVE_SERVICE).lower(): - conn_class = AzureAIServiceWorkspaceConnection - return conn_class + # Connection categories don't perfectly follow perfect camel casing, so lower + # case everything to avoid problems. + CONNECTION_CATEGORY_TO_SUBCLASS_MAP = { + ConnectionCategory.AZURE_OPEN_AI.lower(): AzureOpenAIWorkspaceConnection, + ConnectionCategory.COGNITIVE_SEARCH.lower(): AzureAISearchWorkspaceConnection, + ConnectionCategory.COGNITIVE_SERVICE.lower(): AzureAIServiceWorkspaceConnection, + ConnectionCategory.AZURE_BLOB.lower(): AzureBlobStoreWorkspaceConnection, + } + cat = _snake_to_camel(conn_type).lower() + return CONNECTION_CATEGORY_TO_SUBCLASS_MAP.get(cat, WorkspaceConnection) @classmethod def _get_schema_class_from_type(cls, conn_type: Optional[str]) -> Type: @@ -422,15 +401,28 @@ def _get_schema_class_from_type(cls, conn_type: Optional[str]) -> Type: :rtype: Type """ if conn_type is None: - return WorkspaceConnection + return WorkspaceConnectionSchema + entity_class = cls._get_entity_class_from_type(conn_type) + return entity_class._get_schema_class() - cat = camel_to_snake(conn_type).lower() - conn_class: Type = WorkspaceConnectionSchema - if cat == camel_to_snake(ConnectionCategory.AZURE_OPEN_AI).lower(): - conn_class = OpenAIWorkspaceConnectionSchema - elif cat == camel_to_snake(ConnectionCategory.COGNITIVE_SEARCH).lower(): - conn_class = AzureAISearchWorkspaceConnectionSchema - elif cat == camel_to_snake(ConnectionCategory.COGNITIVE_SERVICE).lower(): - conn_class = AzureAIServiceWorkspaceConnectionSchema + @classmethod + def _get_required_metadata_fields(cls) -> List[str]: + """Helper function that returns the required metadata fields for specific workspace + connection type. This parent function returns nothing, but needs to be overwritten by child + classes, which are created under the expectation that they have extra fields that need to be + accounted for. + + :return: A list of the required metadata fields for the specific workspace connection type. + :rtype: List[str] + """ + return [] + + @classmethod + def _get_schema_class(cls) -> Type: + """Helper function that maps this class to its associated schema class. Needs to be overridden by + child classes to allow the base class to be polymorphic in its schema reading. - return conn_class + :return: The appropriate schema class to use with this entity class. + :rtype: Type + """ + return WorkspaceConnectionSchema diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/connections/workspace_connection_subtypes.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/connections/workspace_connection_subtypes.py index c92bf45a23ac..22048fd6de5a 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/connections/workspace_connection_subtypes.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/connections/workspace_connection_subtypes.py @@ -4,15 +4,27 @@ # pylint: disable=protected-access -from typing import Any, Optional +from typing import Any, List, Optional, Type +from azure.ai.ml._utils.utils import camel_to_snake from azure.ai.ml._utils._experimental import experimental -from azure.ai.ml.constants._common import CONNECTION_API_TYPE_KEY, CONNECTION_API_VERSION_KEY, CONNECTION_KIND_KEY +from azure.ai.ml.constants._common import ( + CONNECTION_API_VERSION_KEY, + CONNECTION_API_TYPE_KEY, + CONNECTION_KIND_KEY, + CONNECTION_ACCOUNT_NAME_KEY, + CONNECTION_CONTAINER_NAME_KEY, +) from azure.ai.ml.entities._credentials import ApiKeyConfiguration - +from azure.ai.ml._schema.workspace.connections.workspace_connection_subtypes import ( + AzureAISearchWorkspaceConnectionSchema, + AzureAIServiceWorkspaceConnectionSchema, + OpenAIWorkspaceConnectionSchema, + AzureBlobStoreWorkspaceConnectionSchema, +) +from azure.ai.ml._restclient.v2024_01_01_preview.models import ConnectionCategory from .workspace_connection import WorkspaceConnection - # Dev notes: Any new classes require modifying the elif chains in the following functions in the # WorkspaceConnection parent class: _from_rest_object, _get_entity_class_from_type, _get_schema_class_from_type @experimental @@ -44,18 +56,33 @@ def __init__( **kwargs: Any, ): kwargs.pop("type", None) # make sure we never somehow use wrong type - super().__init__(target=target, type="azure_open_ai", credentials=credentials, **kwargs) - - if self.tags is not None: - self.tags[CONNECTION_API_VERSION_KEY] = api_version - self.tags[CONNECTION_API_TYPE_KEY] = api_type + super().__init__( + target=target, + type=camel_to_snake(ConnectionCategory.AZURE_OPEN_AI), + credentials=credentials, + from_child=True, + **kwargs, + ) + + if not hasattr(self, "tags") or self.tags is None: + self.tags = {} + self.tags[CONNECTION_API_VERSION_KEY] = api_version + self.tags[CONNECTION_API_TYPE_KEY] = api_type + + @classmethod + def _get_required_metadata_fields(cls) -> List[str]: + return [CONNECTION_API_VERSION_KEY, CONNECTION_API_TYPE_KEY] + + @classmethod + def _get_schema_class(cls) -> Type: + return OpenAIWorkspaceConnectionSchema @property def api_version(self) -> Optional[str]: """The API version of the workspace connection. :return: The API version of the workspace connection. - :rtype: str + :rtype: Optional[str] """ if self.tags is not None and CONNECTION_API_VERSION_KEY in self.tags: res: str = self.tags[CONNECTION_API_VERSION_KEY] @@ -69,15 +96,16 @@ def api_version(self, value: str) -> None: :param value: The new api version to set. :type value: str """ - if self.tags is not None: - self.tags[CONNECTION_API_VERSION_KEY] = value + if not hasattr(self, "tags") or self.tags is None: + self.tags = {} + self.tags[CONNECTION_API_VERSION_KEY] = value @property def api_type(self) -> Optional[str]: """The API type of the workspace connection. :return: The API type of the workspace connection. - :rtype: str + :rtype: Optional[str] """ if self.tags is not None and CONNECTION_API_TYPE_KEY in self.tags: res: str = self.tags[CONNECTION_API_TYPE_KEY] @@ -91,8 +119,9 @@ def api_type(self, value: str) -> None: :param value: The new api type to set. :type value: str """ - if self.tags is not None: - self.tags[CONNECTION_API_TYPE_KEY] = value + if not hasattr(self, "tags") or self.tags is None: + self.tags = {} + self.tags[CONNECTION_API_TYPE_KEY] = value @experimental @@ -121,17 +150,32 @@ def __init__( **kwargs: Any, ): kwargs.pop("type", None) # make sure we never somehow use wrong type - super().__init__(target=target, type="cognitive_search", credentials=credentials, **kwargs) - - if self.tags is not None: - self.tags[CONNECTION_API_VERSION_KEY] = api_version + super().__init__( + target=target, + type=camel_to_snake(ConnectionCategory.COGNITIVE_SEARCH), + credentials=credentials, + from_child=True, + **kwargs, + ) + + if not hasattr(self, "tags") or self.tags is None: + self.tags = {} + self.tags[CONNECTION_API_VERSION_KEY] = api_version + + @classmethod + def _get_required_metadata_fields(cls) -> List[str]: + return [CONNECTION_API_VERSION_KEY] + + @classmethod + def _get_schema_class(cls) -> Type: + return AzureAISearchWorkspaceConnectionSchema @property def api_version(self) -> Optional[str]: """The API version of the workspace connection. :return: The API version of the workspace connection. - :rtype: str + :rtype: Optional[str] """ if self.tags is not None and CONNECTION_API_VERSION_KEY in self.tags: res: str = self.tags[CONNECTION_API_VERSION_KEY] @@ -145,8 +189,9 @@ def api_version(self, value: str) -> None: :param value: The new api version to set. :type value: str """ - if self.tags is not None: - self.tags[CONNECTION_API_VERSION_KEY] = value + if not hasattr(self, "tags") or self.tags is None: + self.tags = {} + self.tags[CONNECTION_API_VERSION_KEY] = value @experimental @@ -178,18 +223,33 @@ def __init__( **kwargs: Any, ): kwargs.pop("type", None) # make sure we never somehow use wrong type - super().__init__(target=target, type="cognitive_service", credentials=credentials, **kwargs) - - if self.tags is not None: - self.tags[CONNECTION_API_VERSION_KEY] = api_version - self.tags[CONNECTION_KIND_KEY] = kind + super().__init__( + target=target, + type=camel_to_snake(ConnectionCategory.COGNITIVE_SERVICE), + credentials=credentials, + from_child=True, + **kwargs, + ) + + if not hasattr(self, "tags") or self.tags is None: + self.tags = {} + self.tags[CONNECTION_API_VERSION_KEY] = api_version + self.tags[CONNECTION_KIND_KEY] = kind + + @classmethod + def _get_required_metadata_fields(cls) -> List[str]: + return [CONNECTION_API_VERSION_KEY, CONNECTION_KIND_KEY] + + @classmethod + def _get_schema_class(cls) -> Type: + return AzureAIServiceWorkspaceConnectionSchema @property def api_version(self) -> Optional[str]: """The API version of the workspace connection. :return: The API version of the workspace connection. - :rtype: str + :rtype: Optional[str] """ if self.tags is not None and CONNECTION_API_VERSION_KEY in self.tags: res: str = self.tags[CONNECTION_API_VERSION_KEY] @@ -203,15 +263,16 @@ def api_version(self, value: str) -> None: :param value: The new api version to set. :type value: str """ - if self.tags is not None: - self.tags[CONNECTION_API_VERSION_KEY] = value + if not hasattr(self, "tags") or self.tags is None: + self.tags = {} + self.tags[CONNECTION_API_VERSION_KEY] = value @property def kind(self) -> Optional[str]: """The kind of the workspace connection. :return: The kind of the workspace connection. - :rtype: str + :rtype: Optional[str] """ if self.tags is not None and CONNECTION_KIND_KEY in self.tags: res: str = self.tags[CONNECTION_KIND_KEY] @@ -225,5 +286,102 @@ def kind(self, value: str) -> None: :param value: The new kind to set. :type value: str """ + if self.tags is None: + self.tags = {} + self.tags[CONNECTION_KIND_KEY] = value + + +@experimental +class AzureBlobStoreWorkspaceConnection(WorkspaceConnection): + """A connection to an Azure Blob Store. Connections of this type are read-only, + creation operations with them are not supported, and this class is only meant to be + instantiated from GET and LIST workspace connection operations. + + :param name: Name of the workspace connection. + :type name: str + :param target: The URL or ARM resource ID of the external resource. + :type target: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict + :param container_name: The name of the container. + :type container_name: str + :param account_name: The name of the account. + :type account_name: str + """ + + def __init__( + self, + *, + target: str, + container_name: str, + account_name: str, + **kwargs, + ): + kwargs.pop("type", None) # make sure we never somehow use wrong type + # Blob store connections returned from the API generally have no credentials, but we still don't want + # to silently run over user inputted connections if they want to play with them locally, so double-check + # kwargs for them. + super().__init__( + target=target, + type=camel_to_snake(ConnectionCategory.AZURE_BLOB), + credentials=kwargs.pop("credentials", None), + from_child=True, + **kwargs, + ) + + if not hasattr(self, "tags") or self.tags is None: + self.tags = {} + self.tags[CONNECTION_CONTAINER_NAME_KEY] = container_name + self.tags[CONNECTION_ACCOUNT_NAME_KEY] = account_name + + @classmethod + def _get_required_metadata_fields(cls) -> List[str]: + return [CONNECTION_CONTAINER_NAME_KEY, CONNECTION_ACCOUNT_NAME_KEY] + + @classmethod + def _get_schema_class(cls) -> Type: + return AzureBlobStoreWorkspaceConnectionSchema + + @property + def container_name(self) -> Optional[str]: + """The name of the workspace connection's container. + + :return: The name of the container. + :rtype: Optional[str] + """ + if self.tags is not None: + return self.tags.get(CONNECTION_CONTAINER_NAME_KEY, None) + return None + + @container_name.setter + def container_name(self, value: str) -> None: + """Set the container name of the workspace connection. + + :param value: The new container name to set. + :type value: str + """ + if self.tags is None: + self.tags = {} + self.tags[CONNECTION_CONTAINER_NAME_KEY] = value + + @property + def account_name(self) -> Optional[str]: + """The name of the workspace connection's account + + :return: The name of the account. + :rtype: Optional[str] + """ if self.tags is not None: - self.tags[CONNECTION_KIND_KEY] = value + return self.tags.get(CONNECTION_ACCOUNT_NAME_KEY, None) + return None + + @account_name.setter + def account_name(self, value: str) -> None: + """Set the account name of the workspace connection. + + :param value: The new account name to set. + :type value: str + """ + if self.tags is None: + self.tags = {} + self.tags[CONNECTION_ACCOUNT_NAME_KEY] = value diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_connections_operations.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_connections_operations.py index f250a8e9a771..bf0af76a7709 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_connections_operations.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_connections_operations.py @@ -132,11 +132,15 @@ def delete(self, name: str) -> None: def list( self, connection_type: Optional[str] = None, + include_data_connections: bool = False, + **kwargs: Any, ) -> Iterable[WorkspaceConnection]: """List all workspace connections for a workspace. :param connection_type: Type of workspace connection to list. :type connection_type: Optional[str] + :param include_data_connections: If true, also return data connections. Defaults to False. + :type include_data_connections: bool :return: An iterator like instance of workspace connection objects :rtype: Iterable[~azure.ai.ml.entities.WorkspaceConnection] @@ -149,6 +153,13 @@ def list( :dedent: 8 :caption: Lists all connections for a workspace for a certain type, in this case "git". """ + + if include_data_connections: + if "params" in kwargs: + kwargs["params"]["includeAll"] = "true" + else: + kwargs["params"] = {"includeAll": "true"} + return cast( Iterable[WorkspaceConnection], self._operation.list( @@ -156,5 +167,6 @@ def list( cls=lambda objs: [WorkspaceConnection._from_rest_object(obj) for obj in objs], category=_snake_to_camel(connection_type) if connection_type else connection_type, **self._scope_kwargs, + **kwargs, ), ) diff --git a/sdk/ml/azure-ai-ml/tests/conftest.py b/sdk/ml/azure-ai-ml/tests/conftest.py index 39f5dc5ee871..1d12b1b37b99 100644 --- a/sdk/ml/azure-ai-ml/tests/conftest.py +++ b/sdk/ml/azure-ai-ml/tests/conftest.py @@ -1101,3 +1101,51 @@ def materialization_identity_client_id(request): def default_storage_account(request): value = request.config.option.default_storage_account return value + + +# Datastore fixtures + + +@pytest.fixture +def blob_store_file() -> str: + return "./tests/test_configs/datastore/blob_store.yml" + + +@pytest.fixture +def blob_store_credential_less_file() -> str: + return "./tests/test_configs/datastore/credential_less_blob_store.yml" + + +@pytest.fixture +def file_store_file() -> str: + return "./tests/test_configs/datastore/file_store.yml" + + +@pytest.fixture +def adls_gen1_file() -> str: + return "./tests/test_configs/datastore/adls_gen1.yml" + + +@pytest.fixture +def adls_gen1_credential_less_file() -> str: + return "./tests/test_configs/datastore/credential_less_adls_gen1.yml" + + +@pytest.fixture +def adls_gen2_file() -> str: + return "./tests/test_configs/datastore/adls_gen2.yml" + + +@pytest.fixture +def adls_gen2_credential_less_file() -> str: + return "./tests/test_configs/datastore/credential_less_adls_gen2.yml" + + +@pytest.fixture +def hdfs_keytab_file() -> str: + return "./tests/test_configs/datastore/hdfs_kerberos_keytab.yml" + + +@pytest.fixture +def hdfs_pw_file() -> str: + return "./tests/test_configs/datastore/hdfs_kerberos_pw.yml" diff --git a/sdk/ml/azure-ai-ml/tests/datastore/e2etests/test_datastore.py b/sdk/ml/azure-ai-ml/tests/datastore/e2etests/test_datastore.py index 770b6d3388e0..812d92f9aaf0 100644 --- a/sdk/ml/azure-ai-ml/tests/datastore/e2etests/test_datastore.py +++ b/sdk/ml/azure-ai-ml/tests/datastore/e2etests/test_datastore.py @@ -10,51 +10,6 @@ from azure.ai.ml.entities._datastore.datastore import Datastore -@pytest.fixture -def blob_store_file() -> str: - return "./tests/test_configs/datastore/blob_store.yml" - - -@pytest.fixture -def blob_store_credential_less_file() -> str: - return "./tests/test_configs/datastore/credential_less_blob_store.yml" - - -@pytest.fixture -def file_store_file() -> str: - return "./tests/test_configs/datastore/file_store.yml" - - -@pytest.fixture -def adls_gen1_file() -> str: - return "./tests/test_configs/datastore/adls_gen1.yml" - - -@pytest.fixture -def adls_gen1_credential_less_file() -> str: - return "./tests/test_configs/datastore/credential_less_adls_gen1.yml" - - -@pytest.fixture -def adls_gen2_file() -> str: - return "./tests/test_configs/datastore/adls_gen2.yml" - - -@pytest.fixture -def adls_gen2_credential_less_file() -> str: - return "./tests/test_configs/datastore/credential_less_adls_gen2.yml" - - -@pytest.fixture -def hdfs_keytab_file() -> str: - return "./tests/test_configs/datastore/hdfs_kerberos_keytab.yml" - - -@pytest.fixture -def hdfs_pw_file() -> str: - return "./tests/test_configs/datastore/hdfs_kerberos_pw.yml" - - def b64read(p): from base64 import b64encode diff --git a/sdk/ml/azure-ai-ml/tests/test_configs/workspace_connection/blob_store.yaml b/sdk/ml/azure-ai-ml/tests/test_configs/workspace_connection/blob_store.yaml new file mode 100644 index 000000000000..d0353e073453 --- /dev/null +++ b/sdk/ml/azure-ai-ml/tests/test_configs/workspace_connection/blob_store.yaml @@ -0,0 +1,10 @@ +name: test_ws_conn_blob_store +type: azure_blob +target: my_endpoint +tags: + four: five +credentials: + type: api_key + key: "9876" +container_name: some_container +account_name: some_account diff --git a/sdk/ml/azure-ai-ml/tests/workspace/e2etests/test_workspace_connections.py b/sdk/ml/azure-ai-ml/tests/workspace/e2etests/test_workspace_connections.py index 94680a475820..1141648c11da 100644 --- a/sdk/ml/azure-ai-ml/tests/workspace/e2etests/test_workspace_connections.py +++ b/sdk/ml/azure-ai-ml/tests/workspace/e2etests/test_workspace_connections.py @@ -1,17 +1,25 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from typing import Callable +from typing import Callable, Tuple import pytest from devtools_testutils import AzureRecordedTestCase -from azure.ai.ml import MLClient, load_workspace_connection -from azure.ai.ml._restclient.v2023_06_01_preview.models import ConnectionAuthType, ConnectionCategory +from azure.ai.ml import MLClient, load_workspace_connection, load_datastore +from azure.ai.ml._restclient.v2024_01_01_preview.models import ConnectionAuthType, ConnectionCategory from azure.ai.ml._utils.utils import camel_to_snake -from azure.ai.ml.entities import WorkspaceConnection, Workspace, WorkspaceHub, ApiKeyConfiguration +from azure.ai.ml.entities import ( + WorkspaceConnection, + Workspace, + WorkspaceHub, + ApiKeyConfiguration, + AzureBlobDatastore, + AzureBlobStoreWorkspaceConnection, +) from azure.ai.ml.constants._common import WorkspaceConnectionTypes from azure.core.exceptions import ResourceNotFoundError +from azure.ai.ml.entities._datastore.datastore import Datastore @pytest.mark.xdist_group(name="workspace_connection") @@ -475,3 +483,59 @@ def test_workspace_connection_is_shared_behavior(self, client: MLClient, randstr del_lean1.result() del_lean2.result() client.workspace_hubs.begin_delete(name=hub.name, delete_dependent_resources=True) + + @pytest.mark.shareTest + def test_workspace_connection_data_connection_listing( + self, + client: MLClient, + randstr: Callable[[], str], + account_keys: Tuple[str, str], + blob_store_file: str, + storage_account_name: str, + ) -> None: + # Workspace connections cannot be used to CREATE datastore connections, so we need + # to use normal datastore operations to first make a blob datastore that we can + # then use to test against the data connection listing functionality. + # This test borrows heavily from the datastore e2e tests. + + # Create a blob datastore. + primary_account_key, secondary_account_key = account_keys + random_name = randstr("random_name") + params_override = [ + {"credentials.account_key": primary_account_key}, + {"name": random_name}, + {"account_name": storage_account_name}, + ] + internal_blob_ds = load_datastore(blob_store_file, params_override=params_override) + created_datastore = datastore_create_get_list(client, internal_blob_ds, random_name) + assert isinstance(created_datastore, AzureBlobDatastore) + assert created_datastore.container_name == internal_blob_ds.container_name + assert created_datastore.account_name == internal_blob_ds.account_name + assert created_datastore.credentials.account_key == primary_account_key + + # Make sure that normal list call doesn't include data connection + assert internal_blob_ds.name not in [conn.name for conn in client.connections.list()] + + # Make sure that the data connection list call includes the data connection + found_datastore_conn = False + for conn in client.connections.list(include_data_connections=True): + if created_datastore.name == conn.name: + assert conn.type == camel_to_snake(ConnectionCategory.AZURE_BLOB) + assert isinstance(conn, AzureBlobStoreWorkspaceConnection) + found_datastore_conn = True + # Ensure that we actually found and validated the data connection. + assert found_datastore_conn + # delete the data store. + client.datastores.delete(random_name) + with pytest.raises(Exception): + client.datastores.get(random_name) + + +# Helper function copied from datstore e2e test. +def datastore_create_get_list(client: MLClient, datastore: Datastore, random_name: str) -> Datastore: + client.datastores.create_or_update(datastore) + datastore = client.datastores.get(random_name, include_secrets=True) + assert datastore.name == random_name + ds_list = client.datastores.list() + assert any(ds.name == datastore.name for ds in ds_list) + return datastore diff --git a/sdk/ml/azure-ai-ml/tests/workspace/unittests/test_workspace_connection_entity.py b/sdk/ml/azure-ai-ml/tests/workspace/unittests/test_workspace_connection_entity.py index 8f826ef668b8..2565fe9138e4 100644 --- a/sdk/ml/azure-ai-ml/tests/workspace/unittests/test_workspace_connection_entity.py +++ b/sdk/ml/azure-ai-ml/tests/workspace/unittests/test_workspace_connection_entity.py @@ -45,6 +45,7 @@ def simple_workspace_connection_validation(ws_connection): "./tests/test_configs/workspace_connection/git_pat.yaml", ) + def test_managed_identity(self): ws_connection = load_workspace_connection( source="./tests/test_configs/workspace_connection/container_registry_managed_identity.yaml" ) @@ -57,6 +58,7 @@ def simple_workspace_connection_validation(ws_connection): assert ws_connection.target == "https://test-feed.com" assert ws_connection.tags == {} + def test_python_feed(self): ws_connection = load_workspace_connection( source="./tests/test_configs/workspace_connection/python_feed_pat.yaml" ) @@ -68,6 +70,7 @@ def simple_workspace_connection_validation(ws_connection): assert ws_connection.target == "https://test-feed.com" assert ws_connection.tags == {} + def test_s3_access_key(self): ws_connection = load_workspace_connection(source="./tests/test_configs/workspace_connection/s3_access_key.yaml") assert ws_connection.type == camel_to_snake(ConnectionCategory.S3) @@ -78,6 +81,7 @@ def simple_workspace_connection_validation(ws_connection): assert ws_connection.target == "dummy" assert ws_connection.tags == {} + def test_snowflake(self): ws_connection = load_workspace_connection( source="./tests/test_configs/workspace_connection/snowflake_user_pwd.yaml" ) @@ -90,6 +94,7 @@ def simple_workspace_connection_validation(ws_connection): assert ws_connection.target == "dummy" assert ws_connection.tags == {} + def test_azure_sql_db(self): ws_connection = load_workspace_connection( source="./tests/test_configs/workspace_connection/azure_sql_db_user_pwd.yaml" ) @@ -102,6 +107,7 @@ def simple_workspace_connection_validation(ws_connection): assert ws_connection.target == "dummy" assert ws_connection.tags == {} + def test_synapse_analytics(self): ws_connection = load_workspace_connection( source="./tests/test_configs/workspace_connection/azure_synapse_analytics_user_pwd.yaml" ) @@ -114,6 +120,7 @@ def simple_workspace_connection_validation(ws_connection): assert ws_connection.target == "dummy" assert ws_connection.tags == {} + def test_my_sql_db(self): ws_connection = load_workspace_connection( source="./tests/test_configs/workspace_connection/azure_my_sql_db_user_pwd.yaml" ) @@ -126,6 +133,7 @@ def simple_workspace_connection_validation(ws_connection): assert ws_connection.target == "dummy" assert ws_connection.tags == {} + def test_postgres_db(self): ws_connection = load_workspace_connection( source="./tests/test_configs/workspace_connection/azure_postgres_db_user_pwd.yaml" ) @@ -138,6 +146,7 @@ def simple_workspace_connection_validation(ws_connection): assert ws_connection.target == "dummy" assert ws_connection.tags == {} + def test_open_ai(self): ws_connection = load_workspace_connection(source="./tests/test_configs/workspace_connection/open_ai.yaml") assert ws_connection.type == camel_to_snake(ConnectionCategory.AZURE_OPEN_AI) @@ -150,6 +159,7 @@ def simple_workspace_connection_validation(ws_connection): assert ws_connection.api_version == None assert ws_connection.api_type == "Azure" + def test_cog_search(self): ws_connection = load_workspace_connection(source="./tests/test_configs/workspace_connection/cog_search.yaml") assert ws_connection.type == camel_to_snake(ConnectionCategory.COGNITIVE_SEARCH) @@ -160,6 +170,7 @@ def simple_workspace_connection_validation(ws_connection): assert ws_connection.tags["ApiVersion"] == "dummy" assert ws_connection.api_version == "dummy" + def test_cog_service(self): ws_connection = load_workspace_connection(source="./tests/test_configs/workspace_connection/cog_service.yaml") assert ws_connection.type == camel_to_snake(ConnectionCategory.COGNITIVE_SERVICE) @@ -172,6 +183,7 @@ def simple_workspace_connection_validation(ws_connection): assert ws_connection.kind == "some_kind" assert ws_connection.api_version == "dummy" + def test_custom(self): ws_connection = load_workspace_connection(source="./tests/test_configs/workspace_connection/custom.yaml") assert ws_connection.type == camel_to_snake(WorkspaceConnectionTypes.CUSTOM) @@ -181,6 +193,18 @@ def simple_workspace_connection_validation(ws_connection): assert ws_connection.target == "my_endpoint" assert ws_connection.tags["one"] == "two" + def test_blob_storage(self): + ws_connection = load_workspace_connection(source="./tests/test_configs/workspace_connection/blob_store.yaml") + + assert ws_connection.type == camel_to_snake("azure_blob") # TODO replace with connetion category reference + assert ws_connection.credentials.type == camel_to_snake(ConnectionAuthType.API_KEY) + assert ws_connection.credentials.key == "9876" + assert ws_connection.name == "test_ws_conn_blob_store" + assert ws_connection.target == "my_endpoint" + assert ws_connection.tags["four"] == "five" + assert ws_connection.container_name == "some_container" + assert ws_connection.account_name == "some_account" + def test_ws_conn_rest_conversion(self): ws_connection = load_workspace_connection(source="./tests/test_configs/workspace_connection/open_ai.yaml") rest_conn = ws_connection._to_rest_object() @@ -193,3 +217,31 @@ def test_ws_conn_rest_conversion(self): assert ws_connection.api_version == new_ws_conn.api_version assert ws_connection.api_type == new_ws_conn.api_type assert ws_connection.is_shared + + def empty_credentials_rest_conversion(self): + ws_connection = WorkspaceConnection( + target="dummy_target", + type=camel_to_snake(ConnectionCategory.PYTHON_FEED), + credentials=None, + name="dummy_connection", + tags=None, + ) + rest_conn = ws_connection._to_rest_object() + new_ws_conn = AzureOpenAIWorkspaceConnection._from_rest_object(rest_obj=rest_conn) + assert new_ws_conn.credentials == None + + def test_ws_conn_subtype_restriction(self): + with pytest.raises(ValueError): + _ = WorkspaceConnection( + target="dummy_target", + type=ConnectionCategory.AZURE_OPEN_AI, + name="dummy_connection", + strict_typing=True, + credentials=None, + ) + _ = AzureOpenAIWorkspaceConnection( + target="dummy_target", + name="dummy_connection", + strict_typing=True, + credentials=None, + )