Skip to content
Merged
Prev Previous commit
Next Next commit
fix mypy errors - 5
  • Loading branch information
CoderKevinZhang committed Feb 29, 2024
commit 58d380f3be997fb01954671ef891ffd862019844
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(
tags=tags,
properties=properties,
display_name=display_name,
is_deterministic=is_deterministic,
is_deterministic=is_deterministic, # type: ignore[arg-type]
inputs=inputs,
outputs=outputs,
yaml_str=yaml_str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(
**kwargs: Any,
):
self._arm_type: str = ""
self.latest_version: str = ""
self.latest_version: str = "" # type: ignore[assignment]
self.image: Optional[str] = None
inference_config = kwargs.pop("inference_config", None)
os_type = kwargs.pop("os_type", None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ def __init__(

if isinstance(component, FlowComponent):
# make input definition fit actual inputs for flow component
with component._inputs._fit_inputs(inputs): # pylint: disable=protected-access
# pylint: disable=protected-access
with component._inputs._fit_inputs(inputs): # type: ignore[attr-defined]
BaseNode.__init__(
self,
type=NodeType.PARALLEL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(
tags: Optional[Dict] = None,
properties: Optional[Dict] = None,
display_name: Optional[str] = None,
is_deterministic: Optional[bool] = True,
is_deterministic: bool = True,
inputs: Optional[Dict] = None,
outputs: Optional[Dict] = None,
yaml_str: Optional[str] = None,
Expand Down Expand Up @@ -423,7 +423,7 @@ def _from_container_rest_object(cls, component_container_rest_object: ComponentC
properties=component_container_details.properties,
type=NodeType._CONTAINER,
# Set this field to None as it hold a default True in init.
is_deterministic=None,
is_deterministic=None, # type: ignore[arg-type]
)
component.latest_version = component_container_details.latest_version
return component
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def _to_rest_object(self) -> ComponentVersion:
def _func(self, **kwargs: Any) -> "Parallel": # pylint: disable=invalid-overridden-method
from azure.ai.ml.entities._builders.parallel import Parallel

with self._inputs._fit_inputs(kwargs): # pylint: disable=protected-access
with self._inputs._fit_inputs(kwargs): # type: ignore[attr-defined]
# pylint: disable=not-callable
return super()._func(**kwargs) # type: ignore

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(
display_name=display_name,
inputs=inputs,
outputs=outputs,
is_deterministic=is_deterministic,
is_deterministic=is_deterministic, # type: ignore[arg-type]
**kwargs,
)
self._jobs = self._process_jobs(jobs) if jobs else {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ def __init__(
:paramtype default: Union[str, int, float, bool]
:keyword min: The minimum value for the input. If a value smaller than the minimum is passed to the job, the job
execution will fail.
:paramtype min: Union[int, float]
:paramtype min: Optional[float]
:keyword max: The maximum value for the input. If a value larger than the maximum is passed to a job, the job
execution will fail.
:paramtype max: Union[int, float]
:paramtype max: Optional[float]
:keyword optional: Specifies if the input is optional.
:paramtype optional: bool
:keyword description: Description of the input
Expand Down Expand Up @@ -150,10 +150,10 @@ def __init__(
:paramtype default: Union[str, int, float, bool]
:keyword min: The minimum value for the input. If a value smaller than the minimum is passed to the job, the job
execution will fail.
:paramtype min: Union[int, float]
:paramtype min: Optional[int]
:keyword max: The maximum value for the input. If a value larger than the maximum is passed to a job, the job
execution will fail.
:paramtype max: Union[int, float]
:paramtype max: Optional[int]
:keyword optional: Specifies if the input is optional.
:paramtype optional: bool
:keyword description: Description of the input
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class Output(_InputOutputBase):
@overload
def __init__(
self,
type: Any,
*,
type: str,
path: Optional[str] = None,
mode: Optional[str] = None,
description: Optional[str] = None,
Expand All @@ -34,6 +35,7 @@ def __init__(
@overload
def __init__(
self,
*,
type: Literal["uri_folder"] = "uri_folder",
path: Optional[str] = None,
mode: Optional[str] = None,
Expand Down Expand Up @@ -83,6 +85,7 @@ def __init__(
@overload
def __init__(
self,
*,
type: Literal["uri_file"] = "uri_file",
path: Optional[str] = None,
mode: Optional[str] = None,
Expand Down Expand Up @@ -111,6 +114,7 @@ def __init__(

def __init__(
self,
*,
type: str = AssetTypes.URI_FOLDER,
path: Optional[str] = None,
mode: Optional[str] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def _build_default_data(self) -> None:
if self._data is None:
# _meta will be None when node._component is not a Component object
# so we just leave the type inference work to backend
self._data = Output(type=None)
self._data = Output(type=None) # type: ignore[call-overload]

def _build_data(self, data: T) -> Any:
"""Build output data according to assigned input, eg: node.outputs.key = data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from abc import ABC
from typing import Any, Optional, cast
from typing import Any, Optional, Union, cast

from azure.ai.ml._restclient.v2023_04_01_preview.models import BanditPolicy as RestBanditPolicy
from azure.ai.ml._restclient.v2023_04_01_preview.models import EarlyTerminationPolicy as RestEarlyTerminationPolicy
Expand Down Expand Up @@ -77,8 +77,8 @@ def __init__(
*,
delay_evaluation: int = 0,
evaluation_interval: int = 0,
slack_amount: Optional[float] = 0,
slack_factor: Optional[float] = 0,
slack_amount: Union[float, None] = 0,
slack_factor: Union[float, None] = 0,
) -> None:
super().__init__(delay_evaluation=delay_evaluation, evaluation_interval=evaluation_interval)
self.type = EarlyTerminationPolicyType.BANDIT.lower()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
:type resources: ~azure.ai.ml.entities.ResourceConfiguration
"""
self.sampling_algorithm = sampling_algorithm
self.early_termination = early_termination
self.early_termination = early_termination # type: ignore[assignment]
self._limits = limits
self.search_space = search_space
self.queue_settings = queue_settings
Expand Down Expand Up @@ -270,7 +270,7 @@ def _get_rest_sampling_algorithm(self) -> RestSamplingAlgorithm:
)

@property
def early_termination(self) -> Any:
def early_termination(self) -> Optional[Union[str, EarlyTerminationPolicy]]:
"""Early termination policy for sweep job.

:returns: Early termination policy for sweep job.
Expand All @@ -285,6 +285,7 @@ def early_termination(self, value: Any) -> None:
:param value: Early termination policy for sweep job.
:type value: ~azure.ai.ml.entities._job.sweep.early_termination_policy.EarlyTerminationPolicy
"""
self._early_termination: Optional[Union[str, EarlyTerminationPolicy]]
if value is None:
self._early_termination = None
elif isinstance(value, EarlyTerminationPolicy):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class Choice(SweepDistribution):
:caption: Using Choice distribution to set values for a hyperparameter sweep
"""

def __init__(self, values: Optional[List] = None, **kwargs: Any) -> None:
def __init__(self, values: Optional[List[Union[float, str, dict]]] = None, **kwargs: Any) -> None:
kwargs.setdefault(TYPE, SearchSpace.CHOICE)
super().__init__(**kwargs)
self.values = values
Expand Down Expand Up @@ -116,7 +116,7 @@ def _from_rest_object(cls, obj: List) -> "Choice":
from_rest_values.append(from_rest_dict)
else:
from_rest_values.append(rest_value)
return Choice(values=from_rest_values)
return Choice(values=from_rest_values) # type: ignore[arg-type]


class Normal(SweepDistribution):
Expand Down
35 changes: 28 additions & 7 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/sweep/sweep_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,19 @@
)
from .objective import Objective
from .parameterized_sweep import ParameterizedSweep
from .search_space import SweepDistribution
from .search_space import (
Choice,
LogNormal,
LogUniform,
Normal,
QLogNormal,
QLogUniform,
QNormal,
QUniform,
Randint,
SweepDistribution,
Uniform,
)

module_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -141,9 +153,16 @@ def __init__(
compute: Optional[str] = None,
limits: Optional[SweepJobLimits] = None,
sampling_algorithm: Optional[Union[str, SamplingAlgorithm]] = None,
search_space: Optional[Dict] = None,
search_space: Optional[
Dict[
str,
Union[
Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
],
]
] = None,
objective: Optional[Objective] = None,
trial: Optional[Union[CommandJob, CommandComponent, ParameterizedCommand]] = None,
trial: Optional[Union[CommandJob, CommandComponent]] = None,
early_termination: Optional[
Union[EarlyTerminationPolicy, BanditPolicy, MedianStoppingPolicy, TruncationSelectionPolicy]
] = None,
Expand Down Expand Up @@ -267,6 +286,10 @@ def _load_from_rest(cls, obj: JobBase) -> "SweepJob":
# Compute also appears in both layers of the yaml, but only one of the REST.
# This should be a required field in one place, but cannot be if its optional in two

_search_space = {}
for param, dist in properties.search_space.items():
_search_space[param] = SweepDistribution._from_rest_object(dist)

return SweepJob(
name=obj.name,
id=obj.id,
Expand All @@ -278,12 +301,10 @@ def _load_from_rest(cls, obj: JobBase) -> "SweepJob":
services=properties.services,
status=properties.status,
creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None,
trial=trial,
trial=trial, # type: ignore[arg-type]
compute=properties.compute_id,
sampling_algorithm=sampling_algorithm,
search_space={
param: SweepDistribution._from_rest_object(dist) for (param, dist) in properties.search_space.items()
},
search_space=_search_space, # type: ignore[arg-type]
limits=SweepJobLimits._from_rest_object(properties.limits),
early_termination=early_termination,
objective=properties.objective,
Expand Down