Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(
tags=tags,
properties=properties,
display_name=display_name,
is_deterministic=is_deterministic,
is_deterministic=is_deterministic, # type: ignore[arg-type]
inputs=inputs,
outputs=outputs,
yaml_str=yaml_str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,18 @@ def _local_scoring_script_is_valid(deployment: OnlineDeployment):
def _code_configuration_contains_cloud_artifacts(deployment: OnlineDeployment):
# If the deployment.code_configuration.code is a string, then it is the cloud code artifact name or full arm ID

return isinstance(deployment.code_configuration.code, str) and (
is_url(deployment.code_configuration.code) or deployment.code_configuration.code.startswith(ARM_ID_PREFIX)
return isinstance(deployment.code_configuration.code, str) and ( # type: ignore[union-attr]
is_url(deployment.code_configuration.code) # type: ignore[union-attr]
or deployment.code_configuration.code.startswith(ARM_ID_PREFIX) # type: ignore[union-attr]
)


def _get_local_code_configuration_artifacts(
deployment: OnlineDeployment,
) -> Path:
return Path(deployment._base_path, deployment.code_configuration.code).resolve()
return Path(
deployment._base_path, deployment.code_configuration.code # type: ignore[union-attr, arg-type]
).resolve()


def _get_cloud_code_configuration_artifacts(code: str, code_operations: CodeOperations, download_path: str) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
type: Optional[str] = None, # pylint: disable=redefined-builtin
path: Optional[Union[str, PathLike]] = None,
utc_time_created: Optional[str] = None,
flavors: Optional[Dict] = None,
flavors: Optional[Dict[str, Dict[str, Any]]] = None,
description: Optional[str] = None,
tags: Optional[Dict] = None,
properties: Optional[Dict] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(
**kwargs: Any,
):
self._arm_type: str = ""
self.latest_version: str = ""
self.latest_version: str = "" # type: ignore[assignment]
self.image: Optional[str] = None
inference_config = kwargs.pop("inference_config", None)
os_type = kwargs.pop("os_type", None)
Expand Down
22 changes: 12 additions & 10 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/command.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=protected-access

import copy
import logging
import os
Expand Down Expand Up @@ -170,7 +168,9 @@ def __init__(
environment: Optional[Union[Environment, str]] = None,
environment_variables: Optional[Dict] = None,
resources: Optional[JobResourceConfiguration] = None,
services: Optional[Dict] = None,
services: Optional[
Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]]
] = None,
queue_settings: Optional[QueueSettings] = None,
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -205,7 +205,7 @@ def __init__(
self.queue_settings = queue_settings

if isinstance(self.component, CommandComponent):
self.resources = self.resources or self.component.resources
self.resources = self.resources or self.component.resources # type: ignore[assignment]
self.distribution = self.distribution or self.component.distribution

self._swept: bool = False
Expand Down Expand Up @@ -277,12 +277,12 @@ def distribution(
self._distribution = value

@property
def resources(self) -> Any:
def resources(self) -> JobResourceConfiguration:
"""The compute resource configuration for the command component or job.

:rtype: ~azure.ai.ml.entities.JobResourceConfiguration
"""
return self._resources
return cast(JobResourceConfiguration, self._resources)

@resources.setter
def resources(self, value: Union[Dict, JobResourceConfiguration]) -> None:
Expand Down Expand Up @@ -381,10 +381,10 @@ def services(
~azure.ai.ml.entities.SshJobService, ~azure.ai.ml.entities.TensorBoardJobService,
~azure.ai.ml.entities.VsCodeJobService]]
"""
self._services = _resolve_job_services(value)
self._services = _resolve_job_services(value) # type: ignore[assignment]

@property
def component(self) -> Any:
def component(self) -> Union[str, CommandComponent]:
"""The ID or instance of the command component or job to be run for the step.

:return: The ID or instance of the command component or job to be run for the step.
Expand Down Expand Up @@ -855,7 +855,9 @@ def _load_from_rest_job(cls, obj: JobBase) -> "Command":
outputs=from_rest_data_outputs(rest_command_job.outputs),
)
command_job._id = obj.id
command_job.resources = JobResourceConfiguration._from_rest_object(rest_command_job.resources)
command_job.resources = cast(
JobResourceConfiguration, JobResourceConfiguration._from_rest_object(rest_command_job.resources)
)
command_job.limits = CommandJobLimits._from_rest_object(rest_command_job.limits)
command_job.queue_settings = QueueSettings._from_rest_object(rest_command_job.queue_settings)
if isinstance(command_job.component, CommandComponent):
Expand Down Expand Up @@ -939,7 +941,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> "Command":


@overload
def _resolve_job_services(services: None) -> None:
def _resolve_job_services(services: Optional[Dict]):
...


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(
compute: Optional[str] = None,
inputs: Optional[Dict] = None,
outputs: Optional[Dict[str, Union[str, Output, "Output"]]] = None,
retry_settings: Optional[Union[RetrySettings, Dict]] = None,
retry_settings: Optional[Union[RetrySettings, Dict[str, str]]] = None,
logging_level: Optional[str] = None,
max_concurrency_per_instance: Optional[int] = None,
error_threshold: Optional[int] = None,
Expand All @@ -132,7 +132,8 @@ def __init__(

if isinstance(component, FlowComponent):
# make input definition fit actual inputs for flow component
with component._inputs._fit_inputs(inputs): # pylint: disable=protected-access
# pylint: disable=protected-access
with component._inputs._fit_inputs(inputs): # type: ignore[attr-defined]
BaseNode.__init__(
self,
type=NodeType.PARALLEL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ManagedIdentityConfiguration,
UserIdentityConfiguration,
)
from azure.ai.ml.entities._job.parallel.retry_settings import RetrySettings
from azure.ai.ml.entities._deployment.deployment_settings import BatchRetrySettings
from azure.ai.ml.entities._job.parallel.run_function import RunFunction

from .command_func import _parse_input, _parse_inputs_outputs, _parse_output
Expand All @@ -27,7 +27,7 @@ def parallel_run_function(
display_name: Optional[str] = None,
experiment_name: Optional[str] = None,
compute: Optional[str] = None,
retry_settings: Optional[RetrySettings] = None,
retry_settings: Optional[BatchRetrySettings] = None,
environment_variables: Optional[Dict] = None,
logging_level: Optional[str] = None,
max_concurrency_per_instance: Optional[int] = None,
Expand Down Expand Up @@ -215,7 +215,7 @@ def parallel_run_function(
description=description,
inputs=component_inputs,
outputs=component_outputs,
retry_settings=retry_settings,
retry_settings=retry_settings, # type: ignore[arg-type]
logging_level=logging_level,
max_concurrency_per_instance=max_concurrency_per_instance,
error_threshold=error_threshold,
Expand All @@ -238,7 +238,7 @@ def parallel_run_function(
description=description,
inputs=component_inputs,
outputs=component_outputs,
retry_settings=retry_settings,
retry_settings=retry_settings, # type: ignore[arg-type]
logging_level=logging_level,
max_concurrency_per_instance=max_concurrency_per_instance,
error_threshold=error_threshold,
Expand Down Expand Up @@ -266,7 +266,7 @@ def parallel_run_function(
outputs=job_outputs,
identity=identity,
environment_variables=environment_variables,
retry_settings=retry_settings,
retry_settings=retry_settings, # type: ignore[arg-type]
logging_level=logging_level,
max_concurrency_per_instance=max_concurrency_per_instance,
error_threshold=error_threshold,
Expand Down
26 changes: 19 additions & 7 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,14 @@ def __init__(
early_termination: Optional[
Union[BanditPolicy, MedianStoppingPolicy, TruncationSelectionPolicy, EarlyTerminationPolicy, str]
] = None,
search_space: Optional[Dict] = None,
search_space: Optional[
Dict[
str,
Union[
Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
],
]
] = None,
inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None,
outputs: Optional[Dict[str, Union[str, Output]]] = None,
identity: Optional[
Expand Down Expand Up @@ -191,7 +198,12 @@ def trial(self) -> CommandComponent:
@property
def search_space(
self,
) -> Optional[Dict]:
) -> Optional[
Dict[
str,
Union[Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform],
]
]:
"""Dictionary of the hyperparameter search space.

Each key is the name of a hyperparameter and its value is the parameter expression.
Expand All @@ -210,7 +222,7 @@ def search_space(self, values: Dict[str, Dict[str, Union[str, int, float, dict]]
:param values: The search space to set.
:type values: Dict[str, Dict[str, Union[str, int, float, dict]]]
"""
search_space = {}
search_space: Dict = {}
for name, value in values.items():
# If value is a SearchSpace object, directly pass it to job.search_space[name]
search_space[name] = self._value_type_to_class(value) if isinstance(value, dict) else value
Expand Down Expand Up @@ -341,7 +353,7 @@ def _to_job(self) -> SweepJob:
sampling_algorithm=self.sampling_algorithm,
search_space=self.search_space,
limits=self.limits,
early_termination=self.early_termination,
early_termination=self.early_termination, # type: ignore[arg-type]
objective=self.objective,
inputs=self._job_inputs,
outputs=self._job_outputs,
Expand Down Expand Up @@ -420,7 +432,7 @@ def __setattr__(self, key: Any, value: Any) -> None:
self.early_termination.slack_factor = None

@property
def early_termination(self) -> Optional[EarlyTerminationPolicy]:
def early_termination(self) -> Optional[Union[str, EarlyTerminationPolicy]]:
"""The early termination policy for the sweep job.

:rtype: Union[str, ~azure.ai.ml.sweep.BanditPolicy, ~azure.ai.ml.sweep.MedianStoppingPolicy,
Expand All @@ -429,7 +441,7 @@ def early_termination(self) -> Optional[EarlyTerminationPolicy]:
return self._early_termination

@early_termination.setter
def early_termination(self, value: Optional[EarlyTerminationPolicy]) -> None:
def early_termination(self, value: Optional[Union[str, EarlyTerminationPolicy]]) -> None:
"""Sets the early termination policy for the sweep job.

:param value: The early termination policy for the sweep job.
Expand All @@ -439,4 +451,4 @@ def early_termination(self, value: Optional[EarlyTerminationPolicy]) -> None:
if isinstance(value, dict):
early_termination_schema = EarlyTerminationField()
value = early_termination_schema._deserialize(value=value, attr=None, data=None)
self._early_termination = value
self._early_termination = value # type: ignore[assignment]
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
DistributionConfiguration,
]
] = None,
resources: Optional[Union[Dict, JobResourceConfiguration]] = None,
resources: Optional[JobResourceConfiguration] = None,
inputs: Optional[Dict] = None,
outputs: Optional[Dict] = None,
instance_count: Optional[int] = None, # promoted property from resources.instance_count
Expand Down Expand Up @@ -150,7 +150,7 @@ def __init__(
self.code = code
self.environment_variables = environment_variables
self.environment = environment
self.resources = resources
self.resources = resources # type: ignore[assignment]
self.distribution = distribution

# check mutual exclusivity of promoted properties
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(
tags: Optional[Dict] = None,
properties: Optional[Dict] = None,
display_name: Optional[str] = None,
is_deterministic: Optional[bool] = True,
is_deterministic: bool = True,
inputs: Optional[Dict] = None,
outputs: Optional[Dict] = None,
yaml_str: Optional[str] = None,
Expand Down Expand Up @@ -423,7 +423,7 @@ def _from_container_rest_object(cls, component_container_rest_object: ComponentC
properties=component_container_details.properties,
type=NodeType._CONTAINER,
# Set this field to None as it hold a default True in init.
is_deterministic=None,
is_deterministic=None, # type: ignore[arg-type]
)
component.latest_version = component_container_details.latest_version
return component
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def _to_rest_object(self) -> ComponentVersion:
def _func(self, **kwargs: Any) -> "Parallel": # pylint: disable=invalid-overridden-method
from azure.ai.ml.entities._builders.parallel import Parallel

with self._inputs._fit_inputs(kwargs): # pylint: disable=protected-access
with self._inputs._fit_inputs(kwargs): # type: ignore[attr-defined]
# pylint: disable=not-callable
return super()._func(**kwargs) # type: ignore

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(
display_name=display_name,
inputs=inputs,
outputs=outputs,
is_deterministic=is_deterministic,
is_deterministic=is_deterministic, # type: ignore[arg-type]
**kwargs,
)
self._jobs = self._process_jobs(jobs) if jobs else {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# pylint: disable=protected-access

import re
from typing import Optional
from typing import Optional, cast

from azure.ai.ml._restclient.v2022_10_01_preview.models import ScriptReference as RestScriptReference
from azure.ai.ml._restclient.v2022_10_01_preview.models import ScriptsToExecute as RestScriptsToExecute
Expand All @@ -24,7 +24,7 @@ class ScriptReference(RestTranslatableMixin):
"""

def __init__(
self, *, path: Optional[str] = None, command: Optional[str] = None, timeout_minutes: Optional[str] = None
self, *, path: Optional[str] = None, command: Optional[str] = None, timeout_minutes: Optional[int] = None
) -> None:
self.path = path
self.command = command
Expand All @@ -47,7 +47,7 @@ def _from_rest_object(cls, obj: RestScriptReference) -> Optional["ScriptReferenc
script_reference = ScriptReference(
path=obj.script_data if obj.script_data else None,
command=obj.script_arguments if obj.script_arguments else None,
timeout_minutes=timeout_minutes,
timeout_minutes=cast(Optional[int], timeout_minutes),
)
return script_reference

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# pylint: disable=protected-access,no-member

from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

from azure.ai.ml._restclient.v2023_04_01_preview.models import (
AzureDataLakeGen1Datastore as RestAzureDatalakeGen1Datastore,
Expand All @@ -14,6 +14,7 @@
from azure.ai.ml._restclient.v2023_04_01_preview.models import DatastoreType
from azure.ai.ml._schema._datastore.adls_gen1 import AzureDataLakeGen1Schema
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
from azure.ai.ml.entities._credentials import CertificateConfiguration, ServicePrincipalConfiguration
from azure.ai.ml.entities._datastore.datastore import Datastore
from azure.ai.ml.entities._datastore.utils import from_rest_datastore_credentials
from azure.ai.ml.entities._util import load_from_dict
Expand Down Expand Up @@ -46,7 +47,7 @@ def __init__(
description: Optional[str] = None,
tags: Optional[Dict] = None,
properties: Optional[Dict] = None,
credentials: Any = None,
credentials: Optional[Union[CertificateConfiguration, ServicePrincipalConfiguration]] = None,
**kwargs: Any
):
kwargs[TYPE] = DatastoreType.AZURE_DATA_LAKE_GEN1
Expand Down Expand Up @@ -81,7 +82,7 @@ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureDataLakeG
id=datastore_resource.id,
name=datastore_resource.name,
store_name=properties.store_name,
credentials=from_rest_datastore_credentials(properties.credentials),
credentials=from_rest_datastore_credentials(properties.credentials), # type: ignore[arg-type]
description=properties.description,
tags=properties.tags,
)
Expand Down
Loading