Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
dacfbc1
Revert change to acs_connection_id type
diondrapeck Jan 24, 2024
a9fc28f
Merge branch 'main' of https://github.com/Azure/azure-sdk-for-python
diondrapeck Jan 25, 2024
594794d
Merge branch 'main' of https://github.com/Azure/azure-sdk-for-python
diondrapeck Jan 25, 2024
9ed7442
Resolve merge conflict
diondrapeck Jan 29, 2024
9495639
Merge branch 'main' of https://github.com/Azure/azure-sdk-for-python
diondrapeck Jan 31, 2024
eb4669d
Merge branch 'main' of https://github.com/Azure/azure-sdk-for-python
diondrapeck Feb 5, 2024
a1bc7b3
Merge branch 'main' of https://github.com/Azure/azure-sdk-for-python
diondrapeck Feb 5, 2024
12a43ef
Merge branch 'main' of https://github.com/Azure/azure-sdk-for-python
diondrapeck Feb 7, 2024
16bebb4
Merge branch 'main' of https://github.com/Azure/azure-sdk-for-python
diondrapeck Feb 8, 2024
e5ee79f
Merge branch 'main' of https://github.com/Azure/azure-sdk-for-python
diondrapeck Feb 20, 2024
e9fdd4f
Merge branch 'main' of https://github.com/Azure/azure-sdk-for-python
diondrapeck Feb 21, 2024
04f17be
Merge branch 'main' of https://github.com/Azure/azure-sdk-for-python
diondrapeck Feb 21, 2024
015db3d
Merge branch 'main' of https://github.com/Azure/azure-sdk-for-python
diondrapeck Feb 22, 2024
91abebd
Merge branch 'main' of https://github.com/Azure/azure-sdk-for-python
diondrapeck Mar 1, 2024
bddcb84
Merge branch 'main' of https://github.com/Azure/azure-sdk-for-python
diondrapeck Mar 1, 2024
a0982e0
Address remaining type issues
diondrapeck Mar 2, 2024
5263c6a
Revert changes to irrelevant files
diondrapeck Mar 2, 2024
dd75b4b
Merge branch 'main' of https://github.com/Azure/azure-sdk-for-python
diondrapeck Mar 2, 2024
9363753
Resolve merge conflict
diondrapeck Mar 2, 2024
804c3b5
Add whitespace back
diondrapeck Mar 2, 2024
26da46a
Remove unintentional change
diondrapeck Mar 2, 2024
c1dd47f
Remove Optional type annotation from get() methods
diondrapeck Mar 4, 2024
48961be
Remove duplicate overloads
diondrapeck Mar 4, 2024
a6b3627
Fix circular import
diondrapeck Mar 4, 2024
f8bbff6
Revert lost input/output overload types
diondrapeck Mar 4, 2024
a57fd06
Remove duplicate overloads for uri_folder
diondrapeck Mar 4, 2024
7770f53
Resolve merge conflict
diondrapeck Mar 4, 2024
2ae891a
Resolve merge conflict
diondrapeck Mar 4, 2024
3861412
Use type:ignore to silence known incompatibilities
diondrapeck Mar 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
import os
import re
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast

from marshmallow import INCLUDE, Schema
Expand All @@ -23,6 +24,7 @@
)
from azure.ai.ml.entities._job.job import Job
from azure.ai.ml.entities._job.parallel.run_function import RunFunction
from azure.ai.ml.entities._job.pipeline._io import NodeOutput

from ..._schema import PathAwareSchema
from ..._utils.utils import is_data_binding_expression
Expand All @@ -31,7 +33,7 @@
from .._component.component import Component
from .._component.flow import FlowComponent
from .._component.parallel_component import ParallelComponent
from .._inputs_outputs import Output
from .._inputs_outputs import Input, Output
from .._job.job_resource_configuration import JobResourceConfiguration
from .._job.parallel.parallel_job import ParallelJob
from .._job.parallel.parallel_task import ParallelTask
Expand Down Expand Up @@ -108,7 +110,7 @@ def __init__(
*,
component: Union[ParallelComponent, str],
compute: Optional[str] = None,
inputs: Optional[Dict] = None,
inputs: Optional[Dict[str, Union[NodeOutput, Input, str, bool, int, float, Enum]]] = None,
outputs: Optional[Dict[str, Union[str, Output, "Output"]]] = None,
retry_settings: Optional[Union[RetrySettings, Dict[str, str]]] = None,
logging_level: Optional[str] = None,
Expand Down
4 changes: 2 additions & 2 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,9 @@ def __setattr__(self, key: Any, value: Any) -> None:
# only one of slack_amount and slack_factor can be specified but default value is 0.0.
# Need to keep track of which one is null.
if self.early_termination.slack_amount == 0.0:
self.early_termination.slack_amount = None
self.early_termination.slack_amount = None # type: ignore[assignment]
if self.early_termination.slack_factor == 0.0:
self.early_termination.slack_factor = None
self.early_termination.slack_factor = None # type: ignore[assignment]

@property
def early_termination(self) -> Optional[Union[str, EarlyTerminationPolicy]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class AzureFileDatastore(Datastore):
:param properties: The asset property dictionary.
:type properties: dict[str, str]
:param credentials: Credentials to use for Azure ML workspace to connect to the storage.
:type credentials: Union[AccountKeySection, SasSection]
:type credentials: Union[~azure.ai.ml.entities.AccountKeyConfiguration,
~azure.ai.ml.entities.SasTokenConfiguration]
:param kwargs: A dictionary of additional configuration parameters.
:type kwargs: dict
"""
Expand All @@ -61,7 +62,7 @@ def __init__(
endpoint: str = _get_storage_endpoint_from_metadata(),
protocol: str = HTTPS,
properties: Optional[Dict] = None,
credentials: Any,
credentials: Union[AccountKeyConfiguration, SasTokenConfiguration],
**kwargs: Any
):
kwargs[TYPE] = DatastoreType.AZURE_FILE
Expand Down Expand Up @@ -97,7 +98,7 @@ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureFileDatas
name=datastore_resource.name,
id=datastore_resource.id,
account_name=properties.account_name,
credentials=from_rest_datastore_credentials(properties.credentials),
credentials=from_rest_datastore_credentials(properties.credentials), # type: ignore[arg-type]
endpoint=properties.endpoint,
protocol=properties.protocol,
file_share_name=properties.file_share_name,
Expand Down Expand Up @@ -260,7 +261,7 @@ def __init__(
endpoint: str = _get_storage_endpoint_from_metadata(),
protocol: str = HTTPS,
properties: Optional[Dict] = None,
credentials: Any = None,
credentials: Optional[Union[AccountKeyConfiguration, SasTokenConfiguration]] = None,
**kwargs: Any
):
kwargs[TYPE] = DatastoreType.AZURE_DATA_LAKE_GEN2
Expand Down Expand Up @@ -299,7 +300,7 @@ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureDataLakeG
name=datastore_resource.name,
id=datastore_resource.id,
account_name=properties.account_name,
credentials=from_rest_datastore_credentials(properties.credentials),
credentials=from_rest_datastore_credentials(properties.credentials), # type: ignore[arg-type]
endpoint=properties.endpoint,
protocol=properties.protocol,
filesystem=properties.filesystem,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
query: Optional[str] = None,
table_name: Optional[str] = None,
stored_procedure: Optional[str] = None,
stored_procedure_params: Optional[List] = None,
stored_procedure_params: Optional[List[Dict]] = None,
connection: Optional[str] = None,
) -> None:
# As an annotation, it is not allowed to initialize the name.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,6 @@ def __init__(
) -> None:
""""""

@overload
def __init__(
self,
*,
type: Literal["uri_folder"] = "uri_folder",
path: Optional[str] = None,
mode: Optional[str] = None,
optional: Optional[bool] = None,
description: Optional[str] = None,
**kwargs: Any,
) -> None:
""""""

@overload
def __init__(
self,
Expand Down
77 changes: 32 additions & 45 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,42 @@ def __init__(

@overload
def __init__(
self,
type: Literal["uri_file"] = "uri_file",
path: Optional[str] = None,
mode: Optional[str] = None,
description: Optional[str] = None,
):
"""Define a URI file output.

:keyword type: The type of the data output. Can only be set to 'uri_file'.
:paramtype type: str
:keyword path: The remote path where the output should be stored.
:paramtype path: str
:keyword mode: The access mode of the data output. Accepted values are
* 'rw_mount': Read-write mount the data,
* 'upload': Upload the data from the compute target,
* 'direct': Pass in the URI as a string
:paramtype mode: str
:keyword description: The description of the output.
:paramtype description: str
:keyword name: The name to be used to register the output as a Data or Model asset. A name can be set without
setting a version.
:paramtype name: str
:keyword version: The version used to register the output as a Data or Model asset. A version can be set only
when name is set.
:paramtype version: str
"""

def __init__( # type: ignore[misc]
self,
*,
type: Literal["uri_folder"] = "uri_folder",
type: str = AssetTypes.URI_FOLDER,
path: Optional[str] = None,
mode: Optional[str] = None,
description: Optional[str] = None,
**kwargs: Any,
) -> None:
# pylint: disable=line-too-long
"""Define an output.

:keyword type: The type of the data output. Accepted values are 'uri_folder', 'uri_file', 'mltable',
Expand All @@ -68,9 +96,8 @@ def __init__(
:paramtype early_available: bool
:keyword intellectual_property: Intellectual property associated with the output.
It can be an instance of `IntellectualProperty` or a dictionary that will be used to create an instance.
:paramtype intellectual_property: Union[~azure.ai.ml.entities._assets.intellectual_property.IntellectualProperty,

dict]
:paramtype intellectual_property: Union[
~azure.ai.ml.entities._assets.intellectual_property.IntellectualProperty, dict]

.. admonition:: Example:

Expand All @@ -81,46 +108,6 @@ def __init__(
:dedent: 8
:caption: Creating a CommandJob with a folder output.
"""

@overload
def __init__(
self,
*,
type: Literal["uri_file"] = "uri_file",
path: Optional[str] = None,
mode: Optional[str] = None,
description: Optional[str] = None,
):
"""Define a URI file output.

:keyword type: The type of the data output. Can only be set to 'uri_file'.
:paramtype type: str
:keyword path: The remote path where the output should be stored.
:paramtype path: str
:keyword mode: The access mode of the data output. Accepted values are
* 'rw_mount': Read-write mount the data,
* 'upload': Upload the data from the compute target,
* 'direct': Pass in the URI as a string
:paramtype mode: str
:keyword description: The description of the output.
:paramtype description: str
:keyword name: The name to be used to register the output as a Data or Model asset. A name can be set without
setting a version.
:paramtype name: str
:keyword version: The version used to register the output as a Data or Model asset. A version can be set only
when name is set.
:paramtype version: str
"""

def __init__(
self,
*,
type: str = AssetTypes.URI_FOLDER,
path: Optional[str] = None,
mode: Optional[str] = None,
description: Optional[str] = None,
**kwargs: Any,
) -> None:
super(Output, self).__init__(type=type)
# As an annotation, it is not allowed to initialize the _port_name.
self._port_name = None
Expand Down
6 changes: 3 additions & 3 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def status(self) -> Optional[str]:
return self._status

@property
def log_files(self) -> Optional[Dict]:
def log_files(self) -> Optional[Dict[str, str]]:
"""Job output files.

:return: The dictionary of log names and URLs.
Expand Down Expand Up @@ -208,10 +208,10 @@ def _resolve_cls_and_type(cls, data: Dict, params_override: Optional[List[Dict]]
from azure.ai.ml.entities._builders.command import Command
from azure.ai.ml.entities._builders.spark import Spark
from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob
from azure.ai.ml.entities._job.import_job import ImportJob
from azure.ai.ml.entities._job.pipeline.pipeline_job import PipelineJob
from azure.ai.ml.entities._job.sweep.sweep_job import SweepJob
from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob

job_type: Optional[Type["Job"]] = None
type_in_override = find_type_in_override(params_override)
Expand Down Expand Up @@ -293,9 +293,9 @@ def _from_rest_object( # pylint: disable=too-many-return-statements
from azure.ai.ml.entities._builders.spark import Spark
from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
from azure.ai.ml.entities._job.base_job import _BaseJob
from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob
from azure.ai.ml.entities._job.import_job import ImportJob
from azure.ai.ml.entities._job.sweep.sweep_job import SweepJob
from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob

try:
if isinstance(obj, Run):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "Parallel":
component=component,
compute=self.compute,
# Need to supply the inputs with double curly.
inputs=self.inputs,
inputs=self.inputs, # type: ignore[arg-type]
outputs=self.outputs, # type: ignore[arg-type]
mini_batch_size=self.mini_batch_size,
partition_keys=self.partition_keys,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def __init__(
*,
delay_evaluation: int = 0,
evaluation_interval: int = 0,
slack_amount: Optional[float] = 0,
slack_factor: Optional[float] = 0,
slack_amount: float = 0,
slack_factor: float = 0,
) -> None:
super().__init__(delay_evaluation=delay_evaluation, evaluation_interval=evaluation_interval)
self.type = EarlyTerminationPolicyType.BANDIT.lower()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _get(self, name: str, version: Optional[str] = None, **kwargs: Dict) -> Feat

@distributed_trace
@monitor_with_activity(ops_logger, "FeatureSet.Get", ActivityType.PUBLICAPI)
def get(self, name: str, version: str, **kwargs: Dict) -> Optional[FeatureSet]: # type: ignore
def get(self, name: str, version: str, **kwargs: Dict) -> FeatureSet: # type: ignore
"""Get the specified FeatureSet asset.

:param name: Name of FeatureSet asset.
Expand All @@ -134,6 +134,8 @@ def get(self, name: str, version: str, **kwargs: Dict) -> Optional[FeatureSet]:
:type version: str
:raises ~azure.ai.ml.exceptions.ValidationException: Raised if FeatureSet cannot be successfully
identified and retrieved. Details will be provided in the error message.
:raises ~azure.core.exceptions.HttpResponseError: Raised if the corresponding name and version cannot be
retrieved from the service.
:return: FeatureSet asset object.
:rtype: ~azure.ai.ml.entities.FeatureSet
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def get(self, name: str, **kwargs: Any) -> FeatureStore:

:param name: Name of the feature store.
:type name: str
:raises ~azure.core.exceptions.HttpResponseError: Raised if the corresponding name and version cannot be
retrieved from the service.
:return: The feature store with the provided name.
:rtype: FeatureStore
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def get(self, name: str) -> Job:

@distributed_trace
@monitor_with_telemetry_mixin(ops_logger, "Job.ShowServices", ActivityType.PUBLICAPI)
def show_services(self, name: str, node_index: int = 0) -> Optional[Dict]:
def show_services(self, name: str, node_index: int = 0) -> Optional[Dict[str, ServiceInstance]]:
"""Gets services associated with a job's node.

:param name: The name of the job.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,23 @@ def list(self, *, scope: str = Scope.RESOURCE_GROUP) -> Iterable[Registry]:
)

@monitor_with_activity(ops_logger, "Registry.Get", ActivityType.PUBLICAPI)
def get(self, name: Optional[str] = None) -> Optional[Registry]:
def get(self, name: Optional[str] = None) -> Registry:
"""Get a registry by name.

:param name: Name of the registry.
:type name: str
:raises ~azure.ai.ml.exceptions.ValidationException: Raised if Registry name cannot be
successfully validated. Details will be provided in the error message.
:raises ~azure.core.exceptions.HttpResponseError: Raised if the corresponding name and version cannot be
retrieved from the service.
:return: The registry with the provided name.
:rtype: ~azure.ai.ml.entities.Registry
"""

registry_name = self._check_registry_name(name)
resource_group = self._resource_group_name
obj = self._operation.get(resource_group, registry_name)
return Registry._from_rest_object(obj)
return Registry._from_rest_object(obj) # type: ignore[return-value]

def _check_registry_name(self, name: Optional[str]) -> str:
registry_name = name or self._default_registry_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,13 @@ def __init__(
self._init_kwargs = kwargs

@monitor_with_activity(ops_logger, "WorkspaceConnections.Get", ActivityType.PUBLICAPI)
def get(self, name: str, **kwargs: Dict) -> Optional[WorkspaceConnection]:
def get(self, name: str, **kwargs: Dict) -> WorkspaceConnection:
"""Get a workspace connection by name.

:param name: Name of the workspace connection.
:type name: str
:raises ~azure.core.exceptions.HttpResponseError: Raised if the corresponding name and version cannot be
retrieved from the service.
:return: The workspace connection with the provided name.
:rtype: ~azure.ai.ml.entities.WorkspaceConnection

Expand All @@ -72,7 +74,7 @@ def get(self, name: str, **kwargs: Dict) -> Optional[WorkspaceConnection]:
**kwargs,
)

return WorkspaceConnection._from_rest_object(rest_obj=obj)
return WorkspaceConnection._from_rest_object(rest_obj=obj) # type: ignore[return-value]

@monitor_with_activity(ops_logger, "WorkspaceConnections.CreateOrUpdate", ActivityType.PUBLICAPI)
def create_or_update(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,13 @@ def list(self, *, scope: str = Scope.RESOURCE_GROUP) -> Iterable[WorkspaceHub]:
@distributed_trace
@monitor_with_activity(ops_logger, "WorkspaceHub.Get", ActivityType.PUBLICAPI)
# pylint: disable=arguments-renamed, arguments-differ
def get(self, name: str, **kwargs: Dict) -> Optional[WorkspaceHub]:
def get(self, name: str, **kwargs: Dict) -> WorkspaceHub:
"""Get a Workspace WorkspaceHub by name.

:param name: Name of the WorkspaceHub.
:type name: str
:raises ~azure.core.exceptions.HttpResponseError: Raised if the corresponding name and version cannot be
retrieved from the service.
:return: The WorkspaceHub with the provided name.
:rtype: ~azure.ai.ml.entities.WorkspaceHub

Expand All @@ -116,7 +118,7 @@ def get(self, name: str, **kwargs: Dict) -> Optional[WorkspaceHub]:
if rest_workspace_obj and rest_workspace_obj.kind and rest_workspace_obj.kind.lower() == WORKSPACE_HUB_KIND:
workspace_hub = WorkspaceHub._from_rest_object(rest_workspace_obj)

return workspace_hub
return workspace_hub # type: ignore[return-value]

@distributed_trace
@monitor_with_activity(ops_logger, "WorkspaceHub.BeginCreate", ActivityType.PUBLICAPI)
Expand Down