Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,20 @@ def _resolve_path(self, value) -> Path:
"""Resolve path to absolute path based on base_path in context.
Will resolve the path if it's already an absolute path.
"""
result = Path(value)
base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
if not result.is_absolute():
result = base_path / result
try:
return result.resolve()
result = Path(value)
base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
if not result.is_absolute():
result = base_path / result

# for non-path string like "azureml:/xxx", OSError can be raised in either
# resolve() or is_dir() or is_file()
result = result.resolve()
if (self._allow_dir and result.is_dir()) or (self._allow_file and result.is_file()):
return result
except OSError:
raise self.make_error("invalid_path")
raise self.make_error("path_not_exist", path=result.as_posix(), allow_type=self.allowed_path_type)

@property
def allowed_path_type(self) -> str:
Expand All @@ -154,16 +160,12 @@ def _validate(self, value):

if value is None:
return
path = self._resolve_path(value)
if (self._allow_dir and path.is_dir()) or (self._allow_file and path.is_file()):
return
raise self.make_error("path_not_exist", path=path.as_posix(), allow_type=self.allowed_path_type)
self._resolve_path(value)

def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[str]:
# do not block serializing None even if required or not allow_none.
if value is None:
return None
self._validate(value)
# always dump path as absolute path in string as base_path will be dropped after serialization
return super(LocalPathField, self)._serialize(self._resolve_path(value).as_posix(), attr, obj, **kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def begin_create_or_update(
operation_config=self._operation_config,
)
if deployment.data_collector:
self._register_collection_data_assets(deployment= deployment)
self._register_collection_data_assets(deployment=deployment)

upload_dependencies(deployment, orchestrators)
try:
Expand Down Expand Up @@ -351,16 +351,13 @@ def _register_collection_data_assets(self, deployment: OnlineDeployment) -> None
for collection in deployment.data_collector.collections:
data_name = deployment.endpoint_name + "-" + deployment.name + "-" + collection
data_object = Data(
name = data_name,
path = deployment.data_collector.destination.path
name=data_name,
path=deployment.data_collector.destination.path
if deployment.data_collector.destination and deployment.data_collector.destination.path
else DEFAULT_MDC_PATH,
is_anonymous= True
)
is_anonymous=True,
)
result = self._all_operations._all_operations[AzureMLResourceType.DATA].create_or_update(data_object)
deployment.data_collector.collections[collection].data = DataAsset(
data_id = result.id,
path = result.path,
name = result.name,
version = result.version
)
data_id=result.id, path=result.path, name=result.name, version=result.version
)
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,21 @@ def test_dump_with_non_existent_base_path(self):
component_entity._base_path = "/non/existent/path"
component_entity._to_dict()

def test_arm_code_from_rest_object(self):
arm_code = (
"azureml:/subscriptions/xxx/resourceGroups/xxx/providers/Microsoft.MachineLearningServices/"
"workspaces/zzz/codes/90b33c11-365d-4ee4-aaa1-224a042deb41/versions/1"
)
yaml_path = "./tests/test_configs/components/helloworld_component.yml"
yaml_component = load_component(yaml_path)

from azure.ai.ml.entities import Component

rest_object = yaml_component._to_rest_object()
rest_object.properties.component_spec["code"] = arm_code
component = Component._from_rest_object(rest_object)
assert component.code == arm_code[8:]


@pytest.mark.timeout(_COMPONENT_TIMEOUT_SECOND)
@pytest.mark.unittest
Expand Down