Skip to content
Merged
Prev Previous commit
Next Next commit
fix mypy errors - 4
  • Loading branch information
CoderKevinZhang committed Feb 28, 2024
commit 9f8c5b4c67800cfde1454d2ad3eac3f42b1f3774
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(
compute: Optional[str] = None,
inputs: Optional[Dict] = None,
outputs: Optional[Dict[str, Union[str, Output, "Output"]]] = None,
retry_settings: Optional[Union[RetrySettings, Dict]] = None,
retry_settings: Optional[Union[RetrySettings, Dict[str, str]]] = None,
logging_level: Optional[str] = None,
max_concurrency_per_instance: Optional[int] = None,
error_threshold: Optional[int] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ManagedIdentityConfiguration,
UserIdentityConfiguration,
)
from azure.ai.ml.entities._job.parallel.retry_settings import RetrySettings
from azure.ai.ml.entities._deployment.deployment_settings import BatchRetrySettings
from azure.ai.ml.entities._job.parallel.run_function import RunFunction

from .command_func import _parse_input, _parse_inputs_outputs, _parse_output
Expand All @@ -27,7 +27,7 @@ def parallel_run_function(
display_name: Optional[str] = None,
experiment_name: Optional[str] = None,
compute: Optional[str] = None,
retry_settings: Optional[RetrySettings] = None,
retry_settings: Optional[BatchRetrySettings] = None,
environment_variables: Optional[Dict] = None,
logging_level: Optional[str] = None,
max_concurrency_per_instance: Optional[int] = None,
Expand Down Expand Up @@ -215,7 +215,7 @@ def parallel_run_function(
description=description,
inputs=component_inputs,
outputs=component_outputs,
retry_settings=retry_settings,
retry_settings=retry_settings, # type: ignore[arg-type]
logging_level=logging_level,
max_concurrency_per_instance=max_concurrency_per_instance,
error_threshold=error_threshold,
Expand All @@ -238,7 +238,7 @@ def parallel_run_function(
description=description,
inputs=component_inputs,
outputs=component_outputs,
retry_settings=retry_settings,
retry_settings=retry_settings, # type: ignore[arg-type]
logging_level=logging_level,
max_concurrency_per_instance=max_concurrency_per_instance,
error_threshold=error_threshold,
Expand Down Expand Up @@ -266,7 +266,7 @@ def parallel_run_function(
outputs=job_outputs,
identity=identity,
environment_variables=environment_variables,
retry_settings=retry_settings,
retry_settings=retry_settings, # type: ignore[arg-type]
logging_level=logging_level,
max_concurrency_per_instance=max_concurrency_per_instance,
error_threshold=error_threshold,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ScriptReference(RestTranslatableMixin):
"""

def __init__(
self, *, path: Optional[str] = None, command: Optional[str] = None, timeout_minutes: Optional[str] = None
self, *, path: Optional[str] = None, command: Optional[str] = None, timeout_minutes: Optional[int] = None
) -> None:
self.path = path
self.command = command
Expand All @@ -47,7 +47,7 @@ def _from_rest_object(cls, obj: RestScriptReference) -> Optional["ScriptReferenc
script_reference = ScriptReference(
path=obj.script_data if obj.script_data else None,
command=obj.script_arguments if obj.script_arguments else None,
timeout_minutes=timeout_minutes,
timeout_minutes=timeout_minutes, # type: ignore[arg-type]
)
return script_reference

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class ParameterizedCommand:

def __init__(
self,
command: Optional[str] = "",
command: Union[str, None] = "",
resources: Optional[Union[dict, JobResourceConfiguration]] = None,
code: Optional[Union[str, os.PathLike]] = None,
environment_variables: Optional[Dict] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
intellectual_property: Optional[IntellectualProperty] = None,
managed_resource_group: Optional[str] = None,
mlflow_registry_uri: Optional[str] = None,
replication_locations: Optional[List],
replication_locations: Optional[List[RegistryRegionDetails]],
**kwargs: Any,
):
"""Azure ML registry.
Expand Down Expand Up @@ -115,7 +115,7 @@ def _to_dict(self) -> Dict:
# limited, and would probably just confuse most users.
if self.replication_locations and len(self.replication_locations) > 0:
if self.replication_locations[0].acr_config and len(self.replication_locations[0].acr_config) > 0:
self.container_registry = self.replication_locations[0].acr_config[0]
self.container_registry = self.replication_locations[0].acr_config[0] # type: ignore[assignment]

res: dict = schema.dump(self)
return res
Expand Down Expand Up @@ -170,7 +170,7 @@ def _from_rest_object(cls, rest_obj: RestRegistry) -> Optional["Registry"]:
else None,
managed_resource_group=real_registry.managed_resource_group,
mlflow_registry_uri=real_registry.ml_flow_registry_uri,
replication_locations=replication_locations,
replication_locations=replication_locations, # type: ignore[arg-type]
)

# There are differences between what our registry validation schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class RegistryRegionDetails:
def __init__(
self,
*,
acr_config: Optional[List] = None,
acr_config: Optional[List[Union[str, SystemCreatedAcrAccount]]] = None,
location: Optional[str] = None,
storage_config: Optional[Union[List[str], SystemCreatedStorageAccount]] = None,
):
Expand Down Expand Up @@ -161,7 +161,9 @@ def _from_rest_object(cls, rest_obj: RestRegistryRegionArmDetails) -> Optional["
storages = cls._storage_config_from_rest_object(rest_obj.storage_account_details)

return RegistryRegionDetails(
acr_config=converted_acr_details, location=rest_obj.location, storage_config=storages
acr_config=converted_acr_details, # type: ignore[arg-type]
location=rest_obj.location,
storage_config=storages,
)

def _to_rest_object(self) -> RestRegistryRegionArmDetails:
Expand Down
7 changes: 3 additions & 4 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_system_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from datetime import datetime
from typing import Union

from typing import Any

from azure.ai.ml._restclient.v2022_10_01.models import SystemData as RestSystemData
from azure.ai.ml.entities import CreatedByType
from azure.ai.ml.entities._mixins import RestTranslatableMixin


Expand Down Expand Up @@ -43,7 +42,7 @@ class SystemData(RestTranslatableMixin):
:paramtype last_modified_at: datetime
"""

def __init__(self, **kwargs: Union[str, CreatedByType, datetime]) -> None:
def __init__(self, **kwargs: Any) -> None:
self.created_by = kwargs.get("created_by", None)
self.created_by_type = kwargs.get("created_by_type", None)
self.created_at = kwargs.get("created_at", None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
from azure.ai.ml._restclient.v2023_04_01_preview import AzureMachineLearningWorkspaces as ServiceClient022023Preview
from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase, ListViewType, UserIdentity
from azure.ai.ml._restclient.v2023_08_01_preview.models import JobType as RestJobType
from azure.ai.ml._restclient.v2024_01_01_preview.models import JobType as RestJobType_20240101
from azure.ai.ml._restclient.v2024_01_01_preview.models import JobBase as JobBase_2401
from azure.ai.ml._restclient.v2024_01_01_preview.models import JobType as RestJobType_20240101
from azure.ai.ml._scope_dependent_operations import (
OperationConfig,
OperationsContainer,
Expand Down Expand Up @@ -72,8 +72,8 @@
from azure.ai.ml.entities._datastore._constants import WORKSPACE_BLOB_STORE
from azure.ai.ml.entities._inputs_outputs import Input
from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob
from azure.ai.ml.entities._job.base_job import _BaseJob
from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob
from azure.ai.ml.entities._job.import_job import ImportJob
from azure.ai.ml.entities._job.job import _is_pipeline_child_job
from azure.ai.ml.entities._job.parallel.parallel_job import ParallelJob
Expand Down Expand Up @@ -365,7 +365,7 @@ def get(self, name: str) -> Job:

@distributed_trace
@monitor_with_telemetry_mixin(logger, "Job.ShowServices", ActivityType.PUBLICAPI)
def show_services(self, name: str, node_index: int = 0) -> Optional[Dict]:
def show_services(self, name: str, node_index: int = 0) -> Optional[Dict[str, ServiceInstance]]:
"""Gets services associated with a job's node.

:param name: The name of the job.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def invoke(
self,
endpoint_name: str,
*,
request_file: Any = None,
request_file: Optional[str] = None,
deployment_name: Optional[str] = None,
# pylint: disable=unused-argument
input_data: Optional[Union[str, Data]] = None,
Expand Down Expand Up @@ -329,7 +329,7 @@ def invoke(
if deployment_name:
self._validate_deployment_name(endpoint_name, deployment_name)

with open(request_file, "rb") as f:
with open(request_file, "rb") as f: # type: ignore[arg-type]
data = json.loads(f.read())
if local:
return self._local_endpoint_helper.invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,17 @@ def delete(self, name: str) -> None:
@monitor_with_activity(logger, "WorkspaceConnections.List", ActivityType.PUBLICAPI)
def list(
self,
*,
connection_type: Optional[str] = None,
include_data_connections: bool = False,
**kwargs: Any,
) -> Iterable[WorkspaceConnection]:
"""List all workspace connections for a workspace.

:param connection_type: Type of workspace connection to list.
:type connection_type: Optional[str]
:param include_data_connections: If true, also return data connections. Defaults to False.
:type include_data_connections: bool
:keyword connection_type: Type of workspace connection to list.
:paramtype connection_type: Optional[str]
:keyword include_data_connections: If true, also return data connections. Defaults to False.
:paramtype include_data_connections: bool
:return: An iterator like instance of workspace connection objects
:rtype: Iterable[~azure.ai.ml.entities.WorkspaceConnection]

Expand Down