Merged
refine
bupt-wenxiaole committed Feb 23, 2024
commit 0786a7af2aa0a1378f055803c2792e6cd8737f08
@@ -11,7 +11,6 @@
 from azure.ai.ml._schema.job.input_output_entry import InputLiteralValueSchema
 from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema
 from azure.ai.ml.constants._common import LoggingLevel
-from azure.ai.ml._schema.core.fields import DataBindingStr, UnionField
 
 from ..core.fields import UnionField
 
@@ -37,47 +36,32 @@ class ParameterizedParallelSchema(PathAwareSchema):
     input_data = fields.Str()
     resources = NestedField(JobResourceConfigurationSchema)
     retry_settings = NestedField(RetrySettingsSchema, unknown=INCLUDE)
-    max_concurrency_per_instance = UnionField(
-        [
-            fields.Integer(
-                dump_default=1,
-                metadata={"description": "The max parallellism that each compute instance has."},
-            ),
-            DataBindingStr()
-        ]
+    max_concurrency_per_instance = fields.Integer(
+        dump_default=1,
+        metadata={"description": "The max parallellism that each compute instance has."},
     )
-    error_threshold = UnionField(
-        [
-            fields.Integer(
-                dump_default=-1,
-                metadata={
-                    "description": (
-                        "The number of item processing failures should be ignored. "
-                        "If the error_threshold is reached, the job terminates. "
-                        "For a list of files as inputs, one item means one file reference. "
-                        "This setting doesn't apply to command parallelization."
-                    )
-                },
-            ),
-            DataBindingStr()
-        ]
+    error_threshold = fields.Integer(
+        dump_default=-1,
+        metadata={
+            "description": (
+                "The number of item processing failures should be ignored. "
+                "If the error_threshold is reached, the job terminates. "
+                "For a list of files as inputs, one item means one file reference. "
+                "This setting doesn't apply to command parallelization."
+            )
+        },
     )
-    mini_batch_error_threshold = UnionField(
-        [
-            fields.Integer(
-                dump_default=-1,
-                metadata={
-                    "description": (
-                        "The number of mini batch processing failures should be ignored. "
-                        "If the mini_batch_error_threshold is reached, the job terminates. "
-                        "For a list of files as inputs, one item means one file reference. "
-                        "This setting can be used by either command or python function parallelization. "
-                        "Only one error_threshold setting can be used in one job."
-                    )
-                },
-            ),
-            DataBindingStr()
-        ]
+    mini_batch_error_threshold = fields.Integer(
+        dump_default=-1,
+        metadata={
+            "description": (
+                "The number of mini batch processing failures should be ignored. "
+                "If the mini_batch_error_threshold is reached, the job terminates. "
+                "For a list of files as inputs, one item means one file reference. "
+                "This setting can be used by either command or python function parallelization. "
+                "Only one error_threshold setting can be used in one job."
+            )
+        },
     )
     environment_variables = UnionField(
         [
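Note on this hunk: it reverts max_concurrency_per_instance, error_threshold, and mini_batch_error_threshold from UnionField([fields.Integer(...), DataBindingStr()]) to plain Integer fields, so a data-binding expression such as "${{parent.inputs.n}}" should no longer pass schema validation for these settings. A minimal sketch of that behavioral difference, using plain marshmallow rather than the azure-ai-ml internals (the ParallelSettings schema below is a hypothetical stand-in):

from marshmallow import Schema, fields, ValidationError

class ParallelSettings(Schema):
    # Plain Integer, mirroring the field definition after this commit.
    max_concurrency_per_instance = fields.Integer(dump_default=1)

print(ParallelSettings().load({"max_concurrency_per_instance": 4}))
# {'max_concurrency_per_instance': 4}
try:
    # A binding expression no longer validates once DataBindingStr is gone.
    ParallelSettings().load({"max_concurrency_per_instance": "${{parent.inputs.n}}"})
except ValidationError as err:
    print(err.messages)  # {'max_concurrency_per_instance': ['Not a valid integer.']}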
@@ -7,10 +7,10 @@
 from marshmallow import fields, post_load
 
 from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
-from azure.ai.ml._schema.core.fields import DataBindingStr, UnionField
 
 
 class ResourceConfigurationSchema(metaclass=PatchedSchemaMeta):
-    instance_count = UnionField([fields.Int(), DataBindingStr()])
+    instance_count = fields.Int()
     instance_type = fields.Str(metadata={"description": "The instance type to make available to this job."})
     properties = fields.Dict(keys=fields.Str())
+
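The same revert applies to ResourceConfigurationSchema: instance_count drops the DataBindingStr alternative and accepts only literal integers. A tiny load example under the same plain-marshmallow assumption (ResourceConfig and the instance type value are illustrative, not SDK names):

from marshmallow import Schema, fields

class ResourceConfig(Schema):
    instance_count = fields.Int()
    instance_type = fields.Str()

print(ResourceConfig().load({"instance_count": 2, "instance_type": "STANDARD_D2_V2"}))
# {'instance_count': 2, 'instance_type': 'STANDARD_D2_V2'}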
@@ -450,7 +450,9 @@ def _to_rest_object(self, **kwargs: Any) -> dict:
                 "partition_keys": json.dumps(self.partition_keys)
                 if self.partition_keys is not None
                 else self.partition_keys,
-                "identity": get_rest_dict_for_node_attrs(self.identity),
+                "identity": self.identity._to_dict()
+                if self.identity and not isinstance(self.identity, Dict)
+                else None,
                 "resources": get_rest_dict_for_node_attrs(self.resources),
             }
         )
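This hunk replaces the generic get_rest_dict_for_node_attrs(self.identity) with an explicit guard: call the identity object's own _to_dict() only when identity is set and is not already a plain dict, otherwise emit None. A minimal sketch of that guard with stand-in types (ManagedIdentityConfig and identity_to_rest are hypothetical, not SDK classes):

from typing import Any, Dict, Optional

class ManagedIdentityConfig:
    def _to_dict(self) -> Dict[str, Any]:
        return {"type": "managed_identity"}

def identity_to_rest(identity: Any) -> Optional[Dict[str, Any]]:
    # Mirrors the guard in the diff: config objects serialize themselves;
    # plain dicts and unset identities fall through to None.
    return identity._to_dict() if identity and not isinstance(identity, dict) else None

print(identity_to_rest(ManagedIdentityConfig()))       # {'type': 'managed_identity'}
print(identity_to_rest({"type": "managed_identity"}))  # None
print(identity_to_rest(None))                          # None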
@@ -479,8 +481,9 @@ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
 
         if "partition_keys" in obj and obj["partition_keys"]:
             obj["partition_keys"] = json.dumps(obj["partition_keys"])
+
         if "identity" in obj and obj["identity"]:
-            obj["identity"] = _BaseJobIdentityConfiguration._from_rest_object(obj["identity"])
+            obj["identity"] = _BaseJobIdentityConfiguration._load(obj["identity"])
         return obj
 
     def _build_inputs(self) -> Dict:
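On the load path, the companion change swaps _BaseJobIdentityConfiguration._from_rest_object for _BaseJobIdentityConfiguration._load when rebuilding init params from a REST payload. The diff does not show _load's body; as a rough sketch, a _load-style entry point might dispatch on the payload's "type" key, along these lines (everything below except the _load name is hypothetical, not the SDK implementation):

from typing import Any, Dict

class _BaseIdentity:
    @classmethod
    def _load(cls, data: Dict[str, Any]) -> "_BaseIdentity":
        # Hypothetical dispatch table; the real SDK identity types differ.
        subclasses = {"managed_identity": ManagedIdentity, "user_identity": UserIdentity}
        return subclasses[data["type"].lower()]()

class ManagedIdentity(_BaseIdentity):
    pass

class UserIdentity(_BaseIdentity):
    pass

print(type(_BaseIdentity._load({"type": "managed_identity"})).__name__)  # ManagedIdentity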