diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/parallel.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/parallel.py index 74ae6e9ae2cc..95b15de57f16 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/parallel.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/parallel.py @@ -9,6 +9,7 @@ import logging import os import re +from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast from marshmallow import INCLUDE, Schema @@ -23,6 +24,7 @@ ) from azure.ai.ml.entities._job.job import Job from azure.ai.ml.entities._job.parallel.run_function import RunFunction +from azure.ai.ml.entities._job.pipeline._io import NodeOutput from ..._schema import PathAwareSchema from ..._utils.utils import is_data_binding_expression @@ -31,7 +33,7 @@ from .._component.component import Component from .._component.flow import FlowComponent from .._component.parallel_component import ParallelComponent -from .._inputs_outputs import Output +from .._inputs_outputs import Input, Output from .._job.job_resource_configuration import JobResourceConfiguration from .._job.parallel.parallel_job import ParallelJob from .._job.parallel.parallel_task import ParallelTask @@ -108,7 +110,7 @@ def __init__( *, component: Union[ParallelComponent, str], compute: Optional[str] = None, - inputs: Optional[Dict] = None, + inputs: Optional[Dict[str, Union[NodeOutput, Input, str, bool, int, float, Enum]]] = None, outputs: Optional[Dict[str, Union[str, Output, "Output"]]] = None, retry_settings: Optional[Union[RetrySettings, Dict[str, str]]] = None, logging_level: Optional[str] = None, diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/sweep.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/sweep.py index e79d95e2f2ca..ec1edc2d6a65 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/sweep.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/sweep.py @@ -427,9 +427,9 @@ def __setattr__(self, key: Any, value: Any) -> None: # only one of slack_amount and slack_factor can be specified but default value is 0.0. # Need to keep track of which one is null. if self.early_termination.slack_amount == 0.0: - self.early_termination.slack_amount = None + self.early_termination.slack_amount = None # type: ignore[assignment] if self.early_termination.slack_factor == 0.0: - self.early_termination.slack_factor = None + self.early_termination.slack_factor = None # type: ignore[assignment] @property def early_termination(self) -> Optional[Union[str, EarlyTerminationPolicy]]: diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_datastore/azure_storage.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_datastore/azure_storage.py index 7eed6e3e3f80..3ad8eb4a8ad5 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_datastore/azure_storage.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_datastore/azure_storage.py @@ -45,7 +45,8 @@ class AzureFileDatastore(Datastore): :param properties: The asset property dictionary. :type properties: dict[str, str] :param credentials: Credentials to use for Azure ML workspace to connect to the storage. - :type credentials: Union[AccountKeySection, SasSection] + :type credentials: Union[~azure.ai.ml.entities.AccountKeyConfiguration, + ~azure.ai.ml.entities.SasTokenConfiguration] :param kwargs: A dictionary of additional configuration parameters. :type kwargs: dict """ @@ -61,7 +62,7 @@ def __init__( endpoint: str = _get_storage_endpoint_from_metadata(), protocol: str = HTTPS, properties: Optional[Dict] = None, - credentials: Any, + credentials: Union[AccountKeyConfiguration, SasTokenConfiguration], **kwargs: Any ): kwargs[TYPE] = DatastoreType.AZURE_FILE @@ -97,7 +98,7 @@ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureFileDatas name=datastore_resource.name, id=datastore_resource.id, account_name=properties.account_name, - credentials=from_rest_datastore_credentials(properties.credentials), + credentials=from_rest_datastore_credentials(properties.credentials), # type: ignore[arg-type] endpoint=properties.endpoint, protocol=properties.protocol, file_share_name=properties.file_share_name, @@ -260,7 +261,7 @@ def __init__( endpoint: str = _get_storage_endpoint_from_metadata(), protocol: str = HTTPS, properties: Optional[Dict] = None, - credentials: Any = None, + credentials: Optional[Union[AccountKeyConfiguration, SasTokenConfiguration]] = None, **kwargs: Any ): kwargs[TYPE] = DatastoreType.AZURE_DATA_LAKE_GEN2 @@ -299,7 +300,7 @@ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureDataLakeG name=datastore_resource.name, id=datastore_resource.id, account_name=properties.account_name, - credentials=from_rest_datastore_credentials(properties.credentials), + credentials=from_rest_datastore_credentials(properties.credentials), # type: ignore[arg-type] endpoint=properties.endpoint, protocol=properties.protocol, filesystem=properties.filesystem, diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/external_data.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/external_data.py index ac6c6608238f..a031d065ad27 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/external_data.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/external_data.py @@ -67,7 +67,7 @@ def __init__( query: Optional[str] = None, table_name: Optional[str] = None, stored_procedure: Optional[str] = None, - stored_procedure_params: Optional[List] = None, + stored_procedure_params: Optional[List[Dict]] = None, connection: Optional[str] = None, ) -> None: # As an annotation, it is not allowed to initialize the name. diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/input.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/input.py index 96a2fe233512..b5e11a188b13 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/input.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/input.py @@ -85,19 +85,6 @@ def __init__( ) -> None: """""" - @overload - def __init__( - self, - *, - type: Literal["uri_folder"] = "uri_folder", - path: Optional[str] = None, - mode: Optional[str] = None, - optional: Optional[bool] = None, - description: Optional[str] = None, - **kwargs: Any, - ) -> None: - """""" - @overload def __init__( self, diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/output.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/output.py index bc245126d20a..573931a53243 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/output.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/output.py @@ -34,14 +34,42 @@ def __init__( @overload def __init__( + self, + type: Literal["uri_file"] = "uri_file", + path: Optional[str] = None, + mode: Optional[str] = None, + description: Optional[str] = None, + ): + """Define a URI file output. + + :keyword type: The type of the data output. Can only be set to 'uri_file'. + :paramtype type: str + :keyword path: The remote path where the output should be stored. + :paramtype path: str + :keyword mode: The access mode of the data output. Accepted values are + * 'rw_mount': Read-write mount the data, + * 'upload': Upload the data from the compute target, + * 'direct': Pass in the URI as a string + :paramtype mode: str + :keyword description: The description of the output. + :paramtype description: str + :keyword name: The name to be used to register the output as a Data or Model asset. A name can be set without + setting a version. + :paramtype name: str + :keyword version: The version used to register the output as a Data or Model asset. A version can be set only + when name is set. + :paramtype version: str + """ + + def __init__( # type: ignore[misc] self, *, - type: Literal["uri_folder"] = "uri_folder", + type: str = AssetTypes.URI_FOLDER, path: Optional[str] = None, mode: Optional[str] = None, description: Optional[str] = None, + **kwargs: Any, ) -> None: - # pylint: disable=line-too-long """Define an output. :keyword type: The type of the data output. Accepted values are 'uri_folder', 'uri_file', 'mltable', @@ -68,9 +96,8 @@ def __init__( :paramtype early_available: bool :keyword intellectual_property: Intellectual property associated with the output. It can be an instance of `IntellectualProperty` or a dictionary that will be used to create an instance. - :paramtype intellectual_property: Union[~azure.ai.ml.entities._assets.intellectual_property.IntellectualProperty, - - dict] + :paramtype intellectual_property: Union[ + ~azure.ai.ml.entities._assets.intellectual_property.IntellectualProperty, dict] .. admonition:: Example: @@ -81,46 +108,6 @@ def __init__( :dedent: 8 :caption: Creating a CommandJob with a folder output. """ - - @overload - def __init__( - self, - *, - type: Literal["uri_file"] = "uri_file", - path: Optional[str] = None, - mode: Optional[str] = None, - description: Optional[str] = None, - ): - """Define a URI file output. - - :keyword type: The type of the data output. Can only be set to 'uri_file'. - :paramtype type: str - :keyword path: The remote path where the output should be stored. - :paramtype path: str - :keyword mode: The access mode of the data output. Accepted values are - * 'rw_mount': Read-write mount the data, - * 'upload': Upload the data from the compute target, - * 'direct': Pass in the URI as a string - :paramtype mode: str - :keyword description: The description of the output. - :paramtype description: str - :keyword name: The name to be used to register the output as a Data or Model asset. A name can be set without - setting a version. - :paramtype name: str - :keyword version: The version used to register the output as a Data or Model asset. A version can be set only - when name is set. - :paramtype version: str - """ - - def __init__( - self, - *, - type: str = AssetTypes.URI_FOLDER, - path: Optional[str] = None, - mode: Optional[str] = None, - description: Optional[str] = None, - **kwargs: Any, - ) -> None: super(Output, self).__init__(type=type) # As an annotation, it is not allowed to initialize the _port_name. self._port_name = None diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/job.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/job.py index 1e3cd393a892..d756d9a26964 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/job.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/job.py @@ -141,7 +141,7 @@ def status(self) -> Optional[str]: return self._status @property - def log_files(self) -> Optional[Dict]: + def log_files(self) -> Optional[Dict[str, str]]: """Job output files. :return: The dictionary of log names and URLs. @@ -208,10 +208,10 @@ def _resolve_cls_and_type(cls, data: Dict, params_override: Optional[List[Dict]] from azure.ai.ml.entities._builders.command import Command from azure.ai.ml.entities._builders.spark import Spark from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob + from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob from azure.ai.ml.entities._job.import_job import ImportJob from azure.ai.ml.entities._job.pipeline.pipeline_job import PipelineJob from azure.ai.ml.entities._job.sweep.sweep_job import SweepJob - from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob job_type: Optional[Type["Job"]] = None type_in_override = find_type_in_override(params_override) @@ -293,9 +293,9 @@ def _from_rest_object( # pylint: disable=too-many-return-statements from azure.ai.ml.entities._builders.spark import Spark from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob from azure.ai.ml.entities._job.base_job import _BaseJob + from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob from azure.ai.ml.entities._job.import_job import ImportJob from azure.ai.ml.entities._job.sweep.sweep_job import SweepJob - from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob try: if isinstance(obj, Run): diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/parallel/parallel_job.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/parallel/parallel_job.py index 9310f65ab0db..b1d2b74f6ff1 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/parallel/parallel_job.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/parallel/parallel_job.py @@ -200,7 +200,7 @@ def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "Parallel": component=component, compute=self.compute, # Need to supply the inputs with double curly. - inputs=self.inputs, + inputs=self.inputs, # type: ignore[arg-type] outputs=self.outputs, # type: ignore[arg-type] mini_batch_size=self.mini_batch_size, partition_keys=self.partition_keys, diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/sweep/early_termination_policy.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/sweep/early_termination_policy.py index 08e1274e573b..b1b928fcc4a4 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/sweep/early_termination_policy.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/sweep/early_termination_policy.py @@ -77,8 +77,8 @@ def __init__( *, delay_evaluation: int = 0, evaluation_interval: int = 0, - slack_amount: Optional[float] = 0, - slack_factor: Optional[float] = 0, + slack_amount: float = 0, + slack_factor: float = 0, ) -> None: super().__init__(delay_evaluation=delay_evaluation, evaluation_interval=evaluation_interval) self.type = EarlyTerminationPolicyType.BANDIT.lower() diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_feature_set_operations.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_feature_set_operations.py index 98cfc0a94951..4703ff800c10 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_feature_set_operations.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_feature_set_operations.py @@ -125,7 +125,7 @@ def _get(self, name: str, version: Optional[str] = None, **kwargs: Dict) -> Feat @distributed_trace @monitor_with_activity(ops_logger, "FeatureSet.Get", ActivityType.PUBLICAPI) - def get(self, name: str, version: str, **kwargs: Dict) -> Optional[FeatureSet]: # type: ignore + def get(self, name: str, version: str, **kwargs: Dict) -> FeatureSet: # type: ignore """Get the specified FeatureSet asset. :param name: Name of FeatureSet asset. @@ -134,6 +134,8 @@ def get(self, name: str, version: str, **kwargs: Dict) -> Optional[FeatureSet]: :type version: str :raises ~azure.ai.ml.exceptions.ValidationException: Raised if FeatureSet cannot be successfully identified and retrieved. Details will be provided in the error message. + :raises ~azure.core.exceptions.HttpResponseError: Raised if the corresponding name and version cannot be + retrieved from the service. :return: FeatureSet asset object. :rtype: ~azure.ai.ml.entities.FeatureSet """ diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_feature_store_operations.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_feature_store_operations.py index 571f5d83c8e1..fdadea599e2a 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_feature_store_operations.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_feature_store_operations.py @@ -118,6 +118,8 @@ def get(self, name: str, **kwargs: Any) -> FeatureStore: :param name: Name of the feature store. :type name: str + :raises ~azure.core.exceptions.HttpResponseError: Raised if the corresponding name and version cannot be + retrieved from the service. :return: The feature store with the provided name. :rtype: FeatureStore """ diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_job_operations.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_job_operations.py index 8a3e9e34c810..f28a5617f5d3 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_job_operations.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_job_operations.py @@ -365,7 +365,7 @@ def get(self, name: str) -> Job: @distributed_trace @monitor_with_telemetry_mixin(ops_logger, "Job.ShowServices", ActivityType.PUBLICAPI) - def show_services(self, name: str, node_index: int = 0) -> Optional[Dict]: + def show_services(self, name: str, node_index: int = 0) -> Optional[Dict[str, ServiceInstance]]: """Gets services associated with a job's node. :param name: The name of the job. diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_registry_operations.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_registry_operations.py index c35b723883e0..2a380c39f820 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_registry_operations.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_registry_operations.py @@ -72,13 +72,15 @@ def list(self, *, scope: str = Scope.RESOURCE_GROUP) -> Iterable[Registry]: ) @monitor_with_activity(ops_logger, "Registry.Get", ActivityType.PUBLICAPI) - def get(self, name: Optional[str] = None) -> Optional[Registry]: + def get(self, name: Optional[str] = None) -> Registry: """Get a registry by name. :param name: Name of the registry. :type name: str :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Registry name cannot be successfully validated. Details will be provided in the error message. + :raises ~azure.core.exceptions.HttpResponseError: Raised if the corresponding name and version cannot be + retrieved from the service. :return: The registry with the provided name. :rtype: ~azure.ai.ml.entities.Registry """ @@ -86,7 +88,7 @@ def get(self, name: Optional[str] = None) -> Optional[Registry]: registry_name = self._check_registry_name(name) resource_group = self._resource_group_name obj = self._operation.get(resource_group, registry_name) - return Registry._from_rest_object(obj) + return Registry._from_rest_object(obj) # type: ignore[return-value] def _check_registry_name(self, name: Optional[str]) -> str: registry_name = name or self._default_registry_name diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_connections_operations.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_connections_operations.py index 41e6049b6206..82e7e24d4b7b 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_connections_operations.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_connections_operations.py @@ -47,11 +47,13 @@ def __init__( self._init_kwargs = kwargs @monitor_with_activity(ops_logger, "WorkspaceConnections.Get", ActivityType.PUBLICAPI) - def get(self, name: str, **kwargs: Dict) -> Optional[WorkspaceConnection]: + def get(self, name: str, **kwargs: Dict) -> WorkspaceConnection: """Get a workspace connection by name. :param name: Name of the workspace connection. :type name: str + :raises ~azure.core.exceptions.HttpResponseError: Raised if the corresponding name and version cannot be + retrieved from the service. :return: The workspace connection with the provided name. :rtype: ~azure.ai.ml.entities.WorkspaceConnection @@ -72,7 +74,7 @@ def get(self, name: str, **kwargs: Dict) -> Optional[WorkspaceConnection]: **kwargs, ) - return WorkspaceConnection._from_rest_object(rest_obj=obj) + return WorkspaceConnection._from_rest_object(rest_obj=obj) # type: ignore[return-value] @monitor_with_activity(ops_logger, "WorkspaceConnections.CreateOrUpdate", ActivityType.PUBLICAPI) def create_or_update( diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_hub_operation.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_hub_operation.py index db488b876961..0786a474c44b 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_hub_operation.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_hub_operation.py @@ -92,11 +92,13 @@ def list(self, *, scope: str = Scope.RESOURCE_GROUP) -> Iterable[WorkspaceHub]: @distributed_trace @monitor_with_activity(ops_logger, "WorkspaceHub.Get", ActivityType.PUBLICAPI) # pylint: disable=arguments-renamed, arguments-differ - def get(self, name: str, **kwargs: Dict) -> Optional[WorkspaceHub]: + def get(self, name: str, **kwargs: Dict) -> WorkspaceHub: """Get a Workspace WorkspaceHub by name. :param name: Name of the WorkspaceHub. :type name: str + :raises ~azure.core.exceptions.HttpResponseError: Raised if the corresponding name and version cannot be + retrieved from the service. :return: The WorkspaceHub with the provided name. :rtype: ~azure.ai.ml.entities.WorkspaceHub @@ -116,7 +118,7 @@ def get(self, name: str, **kwargs: Dict) -> Optional[WorkspaceHub]: if rest_workspace_obj and rest_workspace_obj.kind and rest_workspace_obj.kind.lower() == WORKSPACE_HUB_KIND: workspace_hub = WorkspaceHub._from_rest_object(rest_workspace_obj) - return workspace_hub + return workspace_hub # type: ignore[return-value] @distributed_trace @monitor_with_activity(ops_logger, "WorkspaceHub.BeginCreate", ActivityType.PUBLICAPI)