diff --git a/gms/api/src/main/snapshot/com.linkedin.entity.entities.snapshot.json b/gms/api/src/main/snapshot/com.linkedin.entity.entities.snapshot.json index 6ce21bf375603..6ee6ff6970a20 100644 --- a/gms/api/src/main/snapshot/com.linkedin.entity.entities.snapshot.json +++ b/gms/api/src/main/snapshot/com.linkedin.entity.entities.snapshot.json @@ -3397,6 +3397,7 @@ "name" : "MLFeatureTableProperties", "namespace" : "com.linkedin.ml.metadata", "doc" : "Properties associated with a MLFeatureTable", + "include" : [ "com.linkedin.common.CustomProperties" ], "fields" : [ { "name" : "description", "type" : "string", diff --git a/metadata-ingestion/README.md b/metadata-ingestion/README.md index b7a0b3f8dc8d2..9cf780f9e5b70 100644 --- a/metadata-ingestion/README.md +++ b/metadata-ingestion/README.md @@ -45,6 +45,7 @@ We use a plugin architecture so that you can install only the dependencies you a | oracle | `pip install 'acryl-datahub[oracle]'` | Oracle source | | postgres | `pip install 'acryl-datahub[postgres]'` | Postgres source | | redshift | `pip install 'acryl-datahub[redshift]'` | Redshift source | +| sagemaker | `pip install 'acryl-datahub[sagemaker]'` | AWS SageMaker source | | sqlalchemy | `pip install 'acryl-datahub[sqlalchemy]'` | Generic SQLAlchemy source | | snowflake | `pip install 'acryl-datahub[snowflake]'` | Snowflake source | | snowflake-usage | `pip install 'acryl-datahub[snowflake-usage]'` | Snowflake usage statistics source | @@ -345,6 +346,27 @@ source: # options is same as above ``` +### AWS SageMaker `sagemaker` + +Extracts: + +- Feature groups (support for models, jobs, and more coming soon!) + +```yml +source: + type: sagemaker + config: + aws_region: # aws_region_name, i.e. "eu-west-1" + env: # environment for the DatasetSnapshot URN, one of "DEV", "EI", "PROD" or "CORP". Defaults to "PROD". + + # Credentials. If not specified here, these are picked up according to boto3 rules. + # (see https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html) + aws_access_key_id: # Optional. + aws_secret_access_key: # Optional. + aws_session_token: # Optional. + aws_role: # Optional (Role chaining supported by using a sorted list). +``` + ### Snowflake `snowflake` Extracts: diff --git a/metadata-ingestion/examples/recipes/sagemaker_to_datahub.yml b/metadata-ingestion/examples/recipes/sagemaker_to_datahub.yml new file mode 100644 index 0000000000000..e1914d38bf31f --- /dev/null +++ b/metadata-ingestion/examples/recipes/sagemaker_to_datahub.yml @@ -0,0 +1,10 @@ +# in this example, AWS creds are detected automatically – see the README for more details +source: + type: sagemaker + config: + aws_region: "us-west-2" + +sink: + type: "datahub-rest" + config: + server: "http://localhost:8080" diff --git a/metadata-ingestion/scripts/update_golden_files.sh b/metadata-ingestion/scripts/update_golden_files.sh index 898700a2822ec..6564d5f996d87 100755 --- a/metadata-ingestion/scripts/update_golden_files.sh +++ b/metadata-ingestion/scripts/update_golden_files.sh @@ -6,17 +6,18 @@ set -euxo pipefail pytest --basetemp=tmp || true # Update the golden files. -cp tmp/test_serde_to_json_tests_unit_0/output.json tests/unit/serde/test_serde_large.json -cp tmp/test_serde_to_json_tests_unit_1/output.json tests/unit/serde/test_serde_chart_snapshot.json -cp tmp/test_ldap_ingest0/ldap_mces.json tests/integration/ldap/ldap_mces_golden.json -cp tmp/test_mysql_ingest0/mysql_mces.json tests/integration/mysql/mysql_mces_golden.json -cp tmp/test_mssql_ingest0/mssql_mces.json tests/integration/sql_server/mssql_mces_golden.json -cp tmp/test_mongodb_ingest0/mongodb_mces.json tests/integration/mongodb/mongodb_mces_golden.json -cp tmp/test_feast_ingest0/feast_mces.json tests/integration/feast/feast_mces_golden.json -cp tmp/test_dbt_ingest0/dbt_mces.json tests/integration/dbt/dbt_mces_golden.json -cp tmp/test_glue_ingest0/glue_mces.json tests/unit/glue/glue_mces_golden.json -cp tmp/test_lookml_ingest0/lookml_mces.json tests/integration/lookml/expected_output.json -cp tmp/test_looker_ingest0/looker_mces.json tests/integration/looker/expected_output.json +cp tmp/test_serde_to_json_tests_unit_0/output.json tests/unit/serde/test_serde_large.json +cp tmp/test_serde_to_json_tests_unit_1/output.json tests/unit/serde/test_serde_chart_snapshot.json +cp tmp/test_ldap_ingest0/ldap_mces.json tests/integration/ldap/ldap_mces_golden.json +cp tmp/test_mysql_ingest0/mysql_mces.json tests/integration/mysql/mysql_mces_golden.json +cp tmp/test_mssql_ingest0/mssql_mces.json tests/integration/sql_server/mssql_mces_golden.json +cp tmp/test_mongodb_ingest0/mongodb_mces.json tests/integration/mongodb/mongodb_mces_golden.json +cp tmp/test_feast_ingest0/feast_mces.json tests/integration/feast/feast_mces_golden.json +cp tmp/test_dbt_ingest0/dbt_mces.json tests/integration/dbt/dbt_mces_golden.json +cp tmp/test_glue_ingest0/glue_mces.json tests/unit/glue/glue_mces_golden.json +cp tmp/test_sagemaker_ingest0/sagemaker_mces.json tests/unit/sagemaker/sagemaker_mces_golden.json +cp tmp/test_lookml_ingest0/lookml_mces.json tests/integration/lookml/expected_output.json +cp tmp/test_looker_ingest0/looker_mces.json tests/integration/looker/expected_output.json cp tmp/test_bq_usage_source0/bigquery_usages.json tests/integration/bigquery-usage/bigquery_usages_golden.json # Print success message. diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index 4e15654cbdd3c..3b6994ca858bc 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -57,6 +57,11 @@ def get_long_description(): "sqlalchemy==1.3.24", } +aws_common = { + # AWS Python SDK + "boto3" +} + # Note: for all of these, framework_common will be added. plugins: Dict[str, Set[str]] = { # Sink plugins. @@ -73,7 +78,7 @@ def get_long_description(): "bigquery-usage": {"google-cloud-logging", "cachetools"}, "druid": sql_common | {"pydruid>=0.6.2"}, "feast": {"docker"}, - "glue": {"boto3"}, + "glue": aws_common, "hive": sql_common | { # Acryl Data maintains a fork of PyHive, which adds support for table comments @@ -90,6 +95,7 @@ def get_long_description(): "oracle": sql_common | {"cx_Oracle"}, "postgres": sql_common | {"psycopg2-binary", "GeoAlchemy2"}, "redshift": sql_common | {"sqlalchemy-redshift", "psycopg2-binary", "GeoAlchemy2"}, + "sagemaker": aws_common, "snowflake": sql_common | {"snowflake-sqlalchemy"}, "snowflake-usage": sql_common | {"snowflake-sqlalchemy"}, "superset": {"requests"}, @@ -150,6 +156,7 @@ def get_long_description(): "glue", "hive", "oracle", + "sagemaker", "datahub-kafka", "datahub-rest", # airflow is added below @@ -188,6 +195,7 @@ def get_long_description(): "druid = datahub.ingestion.source.druid:DruidSource", "feast = datahub.ingestion.source.feast:FeastSource", "glue = datahub.ingestion.source.glue:GlueSource", + "sagemaker = datahub.ingestion.source.sagemaker:SagemakerSource", "hive = datahub.ingestion.source.hive:HiveSource", "kafka = datahub.ingestion.source.kafka:KafkaSource", "kafka-connect = datahub.ingestion.source.kafka_connect:KafkaConnectSource", diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws_common.py b/metadata-ingestion/src/datahub/ingestion/source/aws_common.py new file mode 100644 index 0000000000000..03a0c44699fb1 --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/aws_common.py @@ -0,0 +1,87 @@ +from functools import reduce +from typing import List, Optional, Union + +import boto3 + +from datahub.configuration import ConfigModel +from datahub.configuration.common import AllowDenyPattern + + +def assume_role( + role_arn: str, aws_region: str, credentials: Optional[dict] = None +) -> dict: + credentials = credentials or {} + sts_client = boto3.client( + "sts", + region_name=aws_region, + aws_access_key_id=credentials.get("AccessKeyId"), + aws_secret_access_key=credentials.get("SecretAccessKey"), + aws_session_token=credentials.get("SessionToken"), + ) + + assumed_role_object = sts_client.assume_role( + RoleArn=role_arn, RoleSessionName="DatahubIngestionSource" + ) + return assumed_role_object["Credentials"] + + +class AwsSourceConfig(ConfigModel): + """ + Common AWS credentials config. + + Currently used by: + - Glue source + - SageMaker source + """ + + env: str = "PROD" + + database_pattern: AllowDenyPattern = AllowDenyPattern.allow_all() + table_pattern: AllowDenyPattern = AllowDenyPattern.allow_all() + + aws_access_key_id: Optional[str] = None + aws_secret_access_key: Optional[str] = None + aws_session_token: Optional[str] = None + aws_role: Optional[Union[str, List[str]]] = None + aws_region: str + + def get_client(self, service: str) -> boto3.client: + if ( + self.aws_access_key_id + and self.aws_secret_access_key + and self.aws_session_token + ): + return boto3.client( + service, + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + aws_session_token=self.aws_session_token, + region_name=self.aws_region, + ) + elif self.aws_access_key_id and self.aws_secret_access_key: + return boto3.client( + service, + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + region_name=self.aws_region, + ) + elif self.aws_role: + if isinstance(self.aws_role, str): + credentials = assume_role(self.aws_role, self.aws_region) + else: + credentials = reduce( + lambda new_credentials, role_arn: assume_role( + role_arn, self.aws_region, new_credentials + ), + self.aws_role, + {}, + ) + return boto3.client( + service, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + region_name=self.aws_region, + ) + else: + return boto3.client(service, region_name=self.aws_region) diff --git a/metadata-ingestion/src/datahub/ingestion/source/glue.py b/metadata-ingestion/src/datahub/ingestion/source/glue.py index 5c79d99888d78..e22bcfd8a1929 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/glue.py +++ b/metadata-ingestion/src/datahub/ingestion/source/glue.py @@ -3,18 +3,14 @@ from collections import defaultdict from dataclasses import dataclass from dataclasses import field as dataclass_field -from functools import reduce from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union from urllib.parse import urlparse -import boto3 - -from datahub.configuration import ConfigModel -from datahub.configuration.common import AllowDenyPattern from datahub.emitter import mce_builder from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.api.source import Source, SourceReport from datahub.ingestion.api.workunit import MetadataWorkUnit +from datahub.ingestion.source.aws_common import AwsSourceConfig from datahub.metadata.com.linkedin.pegasus2avro.common import AuditStamp, Status from datahub.metadata.com.linkedin.pegasus2avro.metadata.snapshot import DatasetSnapshot from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent @@ -49,79 +45,10 @@ ) -def assume_role( - role_arn: str, aws_region: str, credentials: Optional[dict] = None -) -> dict: - credentials = credentials or {} - sts_client = boto3.client( - "sts", - region_name=aws_region, - aws_access_key_id=credentials.get("AccessKeyId"), - aws_secret_access_key=credentials.get("SecretAccessKey"), - aws_session_token=credentials.get("SessionToken"), - ) - - assumed_role_object = sts_client.assume_role( - RoleArn=role_arn, RoleSessionName="DatahubIngestionSourceGlue" - ) - return assumed_role_object["Credentials"] - - -class GlueSourceConfig(ConfigModel): - env: str = "PROD" - - database_pattern: AllowDenyPattern = AllowDenyPattern.allow_all() - table_pattern: AllowDenyPattern = AllowDenyPattern.allow_all() +class GlueSourceConfig(AwsSourceConfig): extract_transforms: Optional[bool] = True - aws_access_key_id: Optional[str] = None - aws_secret_access_key: Optional[str] = None - aws_session_token: Optional[str] = None - aws_role: Optional[Union[str, List[str]]] = None - aws_region: str - - def get_client(self, service: str) -> boto3.client: - if ( - self.aws_access_key_id - and self.aws_secret_access_key - and self.aws_session_token - ): - return boto3.client( - service, - aws_access_key_id=self.aws_access_key_id, - aws_secret_access_key=self.aws_secret_access_key, - aws_session_token=self.aws_session_token, - region_name=self.aws_region, - ) - elif self.aws_access_key_id and self.aws_secret_access_key: - return boto3.client( - service, - aws_access_key_id=self.aws_access_key_id, - aws_secret_access_key=self.aws_secret_access_key, - region_name=self.aws_region, - ) - elif self.aws_role: - if isinstance(self.aws_role, str): - credentials = assume_role(self.aws_role, self.aws_region) - else: - credentials = reduce( - lambda new_credentials, role_arn: assume_role( - role_arn, self.aws_region, new_credentials - ), - self.aws_role, - {}, - ) - return boto3.client( - service, - aws_access_key_id=credentials["AccessKeyId"], - aws_secret_access_key=credentials["SecretAccessKey"], - aws_session_token=credentials["SessionToken"], - region_name=self.aws_region, - ) - else: - return boto3.client(service, region_name=self.aws_region) - @property def glue_client(self): return self.get_client("glue") diff --git a/metadata-ingestion/src/datahub/ingestion/source/sagemaker.py b/metadata-ingestion/src/datahub/ingestion/source/sagemaker.py new file mode 100644 index 0000000000000..db4e762843106 --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/sagemaker.py @@ -0,0 +1,301 @@ +from dataclasses import dataclass +from dataclasses import field as dataclass_field +from typing import Any, Dict, Iterable, List + +import datahub.emitter.mce_builder as builder +from datahub.ingestion.api.common import PipelineContext +from datahub.ingestion.api.source import Source, SourceReport +from datahub.ingestion.source.aws_common import AwsSourceConfig +from datahub.ingestion.source.metadata_common import MetadataWorkUnit +from datahub.metadata.com.linkedin.pegasus2avro.common import MLFeatureDataType +from datahub.metadata.com.linkedin.pegasus2avro.metadata.snapshot import ( + MLFeatureSnapshot, + MLFeatureTableSnapshot, + MLPrimaryKeySnapshot, +) +from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent +from datahub.metadata.schema_classes import ( + MLFeaturePropertiesClass, + MLFeatureTablePropertiesClass, + MLPrimaryKeyPropertiesClass, +) + + +class SagemakerSourceConfig(AwsSourceConfig): + @property + def sagemaker_client(self): + return self.get_client("sagemaker") + + +@dataclass +class SagemakerSourceReport(SourceReport): + tables_scanned = 0 + filtered: List[str] = dataclass_field(default_factory=list) + + def report_table_scanned(self) -> None: + self.tables_scanned += 1 + + def report_table_dropped(self, table: str) -> None: + self.filtered.append(table) + + +class SagemakerSource(Source): + source_config: SagemakerSourceConfig + report = SagemakerSourceReport() + + def __init__(self, config: SagemakerSourceConfig, ctx: PipelineContext): + super().__init__(ctx) + self.source_config = config + self.report = SagemakerSourceReport() + self.sagemaker_client = config.sagemaker_client + self.env = config.env + + @classmethod + def create(cls, config_dict, ctx): + config = SagemakerSourceConfig.parse_obj(config_dict) + return cls(config, ctx) + + def get_all_feature_groups(self) -> List[Dict[str, Any]]: + """ + List all feature groups in SageMaker. + """ + + feature_groups = [] + + # see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_feature_groups + paginator = self.sagemaker_client.get_paginator("list_feature_groups") + for page in paginator.paginate(): + feature_groups += page["FeatureGroupSummaries"] + + return feature_groups + + def get_feature_group_details(self, feature_group_name: str) -> Dict[str, Any]: + """ + Get details of a feature group (including list of component features). + """ + + # see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_feature_group + feature_group = self.sagemaker_client.describe_feature_group( + FeatureGroupName=feature_group_name + ) + + # use falsy fallback since AWS stubs require this to be a string in tests + next_token = feature_group.get("NextToken", "") + + # paginate over feature group features + while next_token: + next_features = self.sagemaker_client.describe_feature_group( + FeatureGroupName=feature_group_name, NextToken=next_token + ) + feature_group["FeatureDefinitions"].append( + next_features["FeatureDefinitions"] + ) + next_token = feature_group.get("NextToken", "") + + return feature_group + + def get_feature_group_wu( + self, feature_group_details: Dict[str, Any] + ) -> MetadataWorkUnit: + """ + Generate an MLFeatureTable workunit for a SageMaker feature group. + + Parameters + ---------- + feature_group_details: + ingested SageMaker feature group from get_feature_group_details() + """ + + feature_group_name = feature_group_details["FeatureGroupName"] + + feature_group_snapshot = MLFeatureTableSnapshot( + urn=builder.make_ml_feature_table_urn("sagemaker", feature_group_name), + aspects=[], + ) + + feature_group_snapshot.aspects.append( + MLFeatureTablePropertiesClass( + description=feature_group_details.get("Description"), + # non-primary key features + mlFeatures=[ + builder.make_ml_feature_urn( + feature_group_name, + feature["FeatureName"], + ) + for feature in feature_group_details["FeatureDefinitions"] + if feature["FeatureName"] + != feature_group_details["RecordIdentifierFeatureName"] + ], + mlPrimaryKeys=[ + builder.make_ml_primary_key_urn( + feature_group_name, + feature_group_details["RecordIdentifierFeatureName"], + ) + ], + # additional metadata + customProperties={ + "arn": feature_group_details["FeatureGroupArn"], + "creation_time": str(feature_group_details["CreationTime"]), + "status": feature_group_details["FeatureGroupStatus"], + }, + ) + ) + + # make the MCE and workunit + mce = MetadataChangeEvent(proposedSnapshot=feature_group_snapshot) + return MetadataWorkUnit(id=feature_group_name, mce=mce) + + field_type_mappings = { + "String": MLFeatureDataType.TEXT, + "Integral": MLFeatureDataType.ORDINAL, + "Fractional": MLFeatureDataType.CONTINUOUS, + } + + def get_feature_type(self, aws_type: str, feature_name: str) -> str: + + mapped_type = self.field_type_mappings.get(aws_type) + + if mapped_type is None: + self.report.report_warning( + feature_name, f"unable to map type {aws_type} to metadata schema" + ) + mapped_type = MLFeatureDataType.UNKNOWN + + return mapped_type + + def get_feature_wu( + self, feature_group_details: Dict[str, Any], feature: Dict[str, Any] + ) -> MetadataWorkUnit: + """ + Generate an MLFeature workunit for a SageMaker feature. + + Parameters + ---------- + feature_group_details: + ingested SageMaker feature group from get_feature_group_details() + feature: + ingested SageMaker feature + """ + + # if the feature acts as the record identifier, then we ingest it as an MLPrimaryKey + # the RecordIdentifierFeatureName is guaranteed to exist as it's required on creation + is_record_identifier = ( + feature_group_details["RecordIdentifierFeatureName"] + == feature["FeatureName"] + ) + + feature_sources = [] + + if "OfflineStoreConfig" in feature_group_details: + + # remove S3 prefix (s3://) + s3_name = feature_group_details["OfflineStoreConfig"]["S3StorageConfig"][ + "S3Uri" + ][5:] + + if s3_name.endswith("/"): + s3_name = s3_name[:-1] + + feature_sources.append( + builder.make_dataset_urn( + "s3", + s3_name, + self.source_config.env, + ) + ) + + if "DataCatalogConfig" in feature_group_details["OfflineStoreConfig"]: + + # if Glue catalog associated with offline store + glue_database = feature_group_details["OfflineStoreConfig"][ + "DataCatalogConfig" + ]["Database"] + glue_table = feature_group_details["OfflineStoreConfig"][ + "DataCatalogConfig" + ]["TableName"] + + full_table_name = f"{glue_database}.{glue_table}" + + self.report.report_warning( + full_table_name, + f"""Note: table {full_table_name} is an AWS Glue object. + To view full table metadata, run Glue ingestion + (see https://datahubproject.io/docs/metadata-ingestion/#aws-glue-glue)""", + ) + + feature_sources.append( + f"urn:li:dataset:(urn:li:dataPlatform:glue,{full_table_name},{self.source_config.env})" + ) + + # note that there's also an OnlineStoreConfig field, but this + # lacks enough metadata to create a dataset + # (only specifies the security config and whether it's enabled at all) + + # append feature name and type + if is_record_identifier: + primary_key_snapshot: MLPrimaryKeySnapshot = MLPrimaryKeySnapshot( + urn=builder.make_ml_primary_key_urn( + feature_group_details["FeatureGroupName"], + feature["FeatureName"], + ), + aspects=[ + MLPrimaryKeyPropertiesClass( + dataType=self.get_feature_type( + feature["FeatureType"], feature["FeatureName"] + ), + sources=feature_sources, + ), + ], + ) + + # make the MCE and workunit + mce = MetadataChangeEvent(proposedSnapshot=primary_key_snapshot) + else: + # create snapshot instance for the feature + feature_snapshot: MLFeatureSnapshot = MLFeatureSnapshot( + urn=builder.make_ml_feature_urn( + feature_group_details["FeatureGroupName"], + feature["FeatureName"], + ), + aspects=[ + MLFeaturePropertiesClass( + dataType=self.get_feature_type( + feature["FeatureType"], feature["FeatureName"] + ), + sources=feature_sources, + ) + ], + ) + + # make the MCE and workunit + mce = MetadataChangeEvent(proposedSnapshot=feature_snapshot) + + return MetadataWorkUnit( + id=f'{feature_group_details["FeatureGroupName"]}-{feature["FeatureName"]}', + mce=mce, + ) + + def get_workunits(self) -> Iterable[MetadataWorkUnit]: + + feature_groups = self.get_all_feature_groups() + + for feature_group in feature_groups: + + feature_group_details = self.get_feature_group_details( + feature_group["FeatureGroupName"] + ) + + for feature in feature_group_details["FeatureDefinitions"]: + wu = self.get_feature_wu(feature_group_details, feature) + self.report.report_workunit(wu) + yield wu + + wu = self.get_feature_group_wu(feature_group_details) + self.report.report_workunit(wu) + yield wu + + def get_report(self): + return self.report + + def close(self): + pass diff --git a/metadata-ingestion/src/datahub/metadata/schema.avsc b/metadata-ingestion/src/datahub/metadata/schema.avsc index 446a9b6f3eb79..6aebdcc29a5ae 100644 --- a/metadata-ingestion/src/datahub/metadata/schema.avsc +++ b/metadata-ingestion/src/datahub/metadata/schema.avsc @@ -4200,6 +4200,15 @@ "name": "MLFeatureTableProperties", "namespace": "com.linkedin.pegasus2avro.ml.metadata", "fields": [ + { + "type": { + "type": "map", + "values": "string" + }, + "name": "customProperties", + "default": {}, + "doc": "Custom property bag." + }, { "type": [ "null", diff --git a/metadata-ingestion/src/datahub/metadata/schema_classes.py b/metadata-ingestion/src/datahub/metadata/schema_classes.py index 17f7e861eee06..89e5ab1a5ea2b 100644 --- a/metadata-ingestion/src/datahub/metadata/schema_classes.py +++ b/metadata-ingestion/src/datahub/metadata/schema_classes.py @@ -5465,12 +5465,17 @@ class MLFeatureTablePropertiesClass(DictWrapper): RECORD_SCHEMA = get_schema_type("com.linkedin.pegasus2avro.ml.metadata.MLFeatureTableProperties") def __init__(self, + customProperties: Optional[Dict[str, str]]=None, description: Union[None, str]=None, mlFeatures: Union[None, List[str]]=None, mlPrimaryKeys: Union[None, List[str]]=None, ): super().__init__() + if customProperties is None: + self.customProperties = {} + else: + self.customProperties = customProperties self.description = description self.mlFeatures = mlFeatures self.mlPrimaryKeys = mlPrimaryKeys @@ -5483,11 +5488,24 @@ def construct_with_defaults(cls) -> "MLFeatureTablePropertiesClass": return self def _restore_defaults(self) -> None: + self.customProperties = dict() self.description = self.RECORD_SCHEMA.field_map["description"].default self.mlFeatures = self.RECORD_SCHEMA.field_map["mlFeatures"].default self.mlPrimaryKeys = self.RECORD_SCHEMA.field_map["mlPrimaryKeys"].default + @property + def customProperties(self) -> Dict[str, str]: + """Getter: Custom property bag.""" + return self._inner_dict.get('customProperties') # type: ignore + + + @customProperties.setter + def customProperties(self, value: Dict[str, str]) -> None: + """Setter: Custom property bag.""" + self._inner_dict['customProperties'] = value + + @property def description(self) -> Union[None, str]: """Getter: Documentation of the MLFeatureTable""" diff --git a/metadata-ingestion/src/datahub/metadata/schemas/MetadataAuditEvent.avsc b/metadata-ingestion/src/datahub/metadata/schemas/MetadataAuditEvent.avsc index b547ea471f459..80d3c3f93c3e0 100644 --- a/metadata-ingestion/src/datahub/metadata/schemas/MetadataAuditEvent.avsc +++ b/metadata-ingestion/src/datahub/metadata/schemas/MetadataAuditEvent.avsc @@ -4143,6 +4143,15 @@ "namespace": "com.linkedin.pegasus2avro.ml.metadata", "doc": "Properties associated with a MLFeatureTable", "fields": [ + { + "name": "customProperties", + "type": { + "type": "map", + "values": "string" + }, + "doc": "Custom property bag.", + "default": {} + }, { "name": "description", "type": [ diff --git a/metadata-ingestion/src/datahub/metadata/schemas/MetadataChangeEvent.avsc b/metadata-ingestion/src/datahub/metadata/schemas/MetadataChangeEvent.avsc index 49b407ded70d2..ae3452bd9d38b 100644 --- a/metadata-ingestion/src/datahub/metadata/schemas/MetadataChangeEvent.avsc +++ b/metadata-ingestion/src/datahub/metadata/schemas/MetadataChangeEvent.avsc @@ -4142,6 +4142,15 @@ "namespace": "com.linkedin.pegasus2avro.ml.metadata", "doc": "Properties associated with a MLFeatureTable", "fields": [ + { + "name": "customProperties", + "type": { + "type": "map", + "values": "string" + }, + "doc": "Custom property bag.", + "default": {} + }, { "name": "description", "type": [ diff --git a/metadata-ingestion/tests/integration/feast/feast_mces_golden.json b/metadata-ingestion/tests/integration/feast/feast_mces_golden.json index 2956bdd9dcdac..c253da63b4974 100644 --- a/metadata-ingestion/tests/integration/feast/feast_mces_golden.json +++ b/metadata-ingestion/tests/integration/feast/feast_mces_golden.json @@ -343,6 +343,7 @@ "aspects": [ { "com.linkedin.pegasus2avro.ml.metadata.MLFeatureTableProperties": { + "customProperties": {}, "description": null, "mlFeatures": [ "urn:li:mlFeature:(test_feature_table_all_feature_dtypes,test_BOOL_LIST_feature)", @@ -421,6 +422,7 @@ "aspects": [ { "com.linkedin.pegasus2avro.ml.metadata.MLFeatureTableProperties": { + "customProperties": {}, "description": null, "mlFeatures": [ "urn:li:mlFeature:(test_feature_table_no_labels,test_BYTES_feature)" @@ -485,6 +487,7 @@ "aspects": [ { "com.linkedin.pegasus2avro.ml.metadata.MLFeatureTableProperties": { + "customProperties": {}, "description": null, "mlFeatures": [ "urn:li:mlFeature:(test_feature_table_single_feature,test_BYTES_feature)" diff --git a/metadata-ingestion/tests/unit/sagemaker/sagemaker_mces_golden.json b/metadata-ingestion/tests/unit/sagemaker/sagemaker_mces_golden.json new file mode 100644 index 0000000000000..06ec895eb219d --- /dev/null +++ b/metadata-ingestion/tests/unit/sagemaker/sagemaker_mces_golden.json @@ -0,0 +1,286 @@ +[ + { + "auditHeader": null, + "proposedSnapshot": { + "com.linkedin.pegasus2avro.metadata.snapshot.MLFeatureSnapshot": { + "urn": "urn:li:mlFeature:(test-2,some-feature-1)", + "aspects": [ + { + "com.linkedin.pegasus2avro.ml.metadata.MLFeatureProperties": { + "description": null, + "dataType": "TEXT", + "version": null, + "sources": [ + "urn:li:dataset:(urn:li:dataPlatform:s3,datahub-sagemaker-outputs,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:glue,sagemaker_featurestore.test-2-123412341234,PROD)" + ] + } + } + ] + } + }, + "proposedDelta": null + }, + { + "auditHeader": null, + "proposedSnapshot": { + "com.linkedin.pegasus2avro.metadata.snapshot.MLPrimaryKeySnapshot": { + "urn": "urn:li:mlPrimaryKey:(test-2,some-feature-2)", + "aspects": [ + { + "com.linkedin.pegasus2avro.ml.metadata.MLPrimaryKeyProperties": { + "description": null, + "dataType": "ORDINAL", + "version": null, + "sources": [ + "urn:li:dataset:(urn:li:dataPlatform:s3,datahub-sagemaker-outputs,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:glue,sagemaker_featurestore.test-2-123412341234,PROD)" + ] + } + } + ] + } + }, + "proposedDelta": null + }, + { + "auditHeader": null, + "proposedSnapshot": { + "com.linkedin.pegasus2avro.metadata.snapshot.MLFeatureSnapshot": { + "urn": "urn:li:mlFeature:(test-2,some-feature-3)", + "aspects": [ + { + "com.linkedin.pegasus2avro.ml.metadata.MLFeatureProperties": { + "description": null, + "dataType": "CONTINUOUS", + "version": null, + "sources": [ + "urn:li:dataset:(urn:li:dataPlatform:s3,datahub-sagemaker-outputs,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:glue,sagemaker_featurestore.test-2-123412341234,PROD)" + ] + } + } + ] + } + }, + "proposedDelta": null + }, + { + "auditHeader": null, + "proposedSnapshot": { + "com.linkedin.pegasus2avro.metadata.snapshot.MLFeatureTableSnapshot": { + "urn": "urn:li:mlFeatureTable:(urn:li:dataPlatform:sagemaker,test-2)", + "aspects": [ + { + "com.linkedin.pegasus2avro.ml.metadata.MLFeatureTableProperties": { + "customProperties": { + "arn": "arn:aws:sagemaker:us-west-2:123412341234:feature-group/test-2", + "creation_time": "2021-06-24 09:48:37.035000", + "status": "Created" + }, + "description": "Yet another test feature group", + "mlFeatures": [ + "urn:li:mlFeature:(test-2,some-feature-1)", + "urn:li:mlFeature:(test-2,some-feature-3)" + ], + "mlPrimaryKeys": [ + "urn:li:mlPrimaryKey:(test-2,some-feature-2)" + ] + } + } + ] + } + }, + "proposedDelta": null + }, + { + "auditHeader": null, + "proposedSnapshot": { + "com.linkedin.pegasus2avro.metadata.snapshot.MLFeatureSnapshot": { + "urn": "urn:li:mlFeature:(test-1,name)", + "aspects": [ + { + "com.linkedin.pegasus2avro.ml.metadata.MLFeatureProperties": { + "description": null, + "dataType": "TEXT", + "version": null, + "sources": [] + } + } + ] + } + }, + "proposedDelta": null + }, + { + "auditHeader": null, + "proposedSnapshot": { + "com.linkedin.pegasus2avro.metadata.snapshot.MLPrimaryKeySnapshot": { + "urn": "urn:li:mlPrimaryKey:(test-1,id)", + "aspects": [ + { + "com.linkedin.pegasus2avro.ml.metadata.MLPrimaryKeyProperties": { + "description": null, + "dataType": "ORDINAL", + "version": null, + "sources": [] + } + } + ] + } + }, + "proposedDelta": null + }, + { + "auditHeader": null, + "proposedSnapshot": { + "com.linkedin.pegasus2avro.metadata.snapshot.MLFeatureSnapshot": { + "urn": "urn:li:mlFeature:(test-1,height)", + "aspects": [ + { + "com.linkedin.pegasus2avro.ml.metadata.MLFeatureProperties": { + "description": null, + "dataType": "CONTINUOUS", + "version": null, + "sources": [] + } + } + ] + } + }, + "proposedDelta": null + }, + { + "auditHeader": null, + "proposedSnapshot": { + "com.linkedin.pegasus2avro.metadata.snapshot.MLFeatureSnapshot": { + "urn": "urn:li:mlFeature:(test-1,time)", + "aspects": [ + { + "com.linkedin.pegasus2avro.ml.metadata.MLFeatureProperties": { + "description": null, + "dataType": "TEXT", + "version": null, + "sources": [] + } + } + ] + } + }, + "proposedDelta": null + }, + { + "auditHeader": null, + "proposedSnapshot": { + "com.linkedin.pegasus2avro.metadata.snapshot.MLFeatureTableSnapshot": { + "urn": "urn:li:mlFeatureTable:(urn:li:dataPlatform:sagemaker,test-1)", + "aspects": [ + { + "com.linkedin.pegasus2avro.ml.metadata.MLFeatureTableProperties": { + "customProperties": { + "arn": "arn:aws:sagemaker:us-west-2:123412341234:feature-group/test-1", + "creation_time": "2021-06-23 13:58:10.264000", + "status": "Created" + }, + "description": "First test feature group", + "mlFeatures": [ + "urn:li:mlFeature:(test-1,name)", + "urn:li:mlFeature:(test-1,height)", + "urn:li:mlFeature:(test-1,time)" + ], + "mlPrimaryKeys": [ + "urn:li:mlPrimaryKey:(test-1,id)" + ] + } + } + ] + } + }, + "proposedDelta": null + }, + { + "auditHeader": null, + "proposedSnapshot": { + "com.linkedin.pegasus2avro.metadata.snapshot.MLPrimaryKeySnapshot": { + "urn": "urn:li:mlPrimaryKey:(test,feature_1)", + "aspects": [ + { + "com.linkedin.pegasus2avro.ml.metadata.MLPrimaryKeyProperties": { + "description": null, + "dataType": "TEXT", + "version": null, + "sources": [] + } + } + ] + } + }, + "proposedDelta": null + }, + { + "auditHeader": null, + "proposedSnapshot": { + "com.linkedin.pegasus2avro.metadata.snapshot.MLFeatureSnapshot": { + "urn": "urn:li:mlFeature:(test,feature_2)", + "aspects": [ + { + "com.linkedin.pegasus2avro.ml.metadata.MLFeatureProperties": { + "description": null, + "dataType": "ORDINAL", + "version": null, + "sources": [] + } + } + ] + } + }, + "proposedDelta": null + }, + { + "auditHeader": null, + "proposedSnapshot": { + "com.linkedin.pegasus2avro.metadata.snapshot.MLFeatureSnapshot": { + "urn": "urn:li:mlFeature:(test,feature_3)", + "aspects": [ + { + "com.linkedin.pegasus2avro.ml.metadata.MLFeatureProperties": { + "description": null, + "dataType": "CONTINUOUS", + "version": null, + "sources": [] + } + } + ] + } + }, + "proposedDelta": null + }, + { + "auditHeader": null, + "proposedSnapshot": { + "com.linkedin.pegasus2avro.metadata.snapshot.MLFeatureTableSnapshot": { + "urn": "urn:li:mlFeatureTable:(urn:li:dataPlatform:sagemaker,test)", + "aspects": [ + { + "com.linkedin.pegasus2avro.ml.metadata.MLFeatureTableProperties": { + "customProperties": { + "arn": "arn:aws:sagemaker:us-west-2:123412341234:feature-group/test", + "creation_time": "2021-06-14 11:03:00.803000", + "status": "Created" + }, + "description": null, + "mlFeatures": [ + "urn:li:mlFeature:(test,feature_2)", + "urn:li:mlFeature:(test,feature_3)" + ], + "mlPrimaryKeys": [ + "urn:li:mlPrimaryKey:(test,feature_1)" + ] + } + } + ] + } + }, + "proposedDelta": null + } +] \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source.py b/metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source.py new file mode 100644 index 0000000000000..ee2f23441f527 --- /dev/null +++ b/metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source.py @@ -0,0 +1,73 @@ +import json + +from botocore.stub import Stubber +from freezegun import freeze_time + +from datahub.ingestion.api.common import PipelineContext +from datahub.ingestion.source.sagemaker import SagemakerSource, SagemakerSourceConfig +from tests.test_helpers import mce_helpers +from tests.unit.test_sagemaker_source_stubs import ( + describe_feature_group_response_1, + describe_feature_group_response_2, + describe_feature_group_response_3, + list_feature_groups_response, +) + +FROZEN_TIME = "2020-04-14 07:00:00" + + +def sagemaker_source() -> SagemakerSource: + return SagemakerSource( + ctx=PipelineContext(run_id="sagemaker-source-test"), + config=SagemakerSourceConfig(aws_region="us-west-2"), + ) + + +@freeze_time(FROZEN_TIME) +def test_sagemaker_ingest(tmp_path, pytestconfig): + + sagemaker_source_instance = sagemaker_source() + + with Stubber(sagemaker_source_instance.sagemaker_client) as sagemaker_stubber: + + sagemaker_stubber.add_response( + "list_feature_groups", + list_feature_groups_response, + {}, + ) + sagemaker_stubber.add_response( + "describe_feature_group", + describe_feature_group_response_1, + { + "FeatureGroupName": "test-2", + }, + ) + sagemaker_stubber.add_response( + "describe_feature_group", + describe_feature_group_response_2, + { + "FeatureGroupName": "test-1", + }, + ) + sagemaker_stubber.add_response( + "describe_feature_group", + describe_feature_group_response_3, + { + "FeatureGroupName": "test", + }, + ) + + mce_objects = [ + wu.mce.to_obj() for wu in sagemaker_source_instance.get_workunits() + ] + + with open(str(tmp_path / "sagemaker_mces.json"), "w") as f: + json.dump(mce_objects, f, indent=2) + + output = mce_helpers.load_json_file(str(tmp_path / "sagemaker_mces.json")) + + test_resources_dir = pytestconfig.rootpath / "tests/unit/sagemaker" + golden = mce_helpers.load_json_file( + str(test_resources_dir / "sagemaker_mces_golden.json") + ) + mce_helpers.assert_mces_equal(output, golden) diff --git a/metadata-ingestion/tests/unit/test_sagemaker_source_stubs.py b/metadata-ingestion/tests/unit/test_sagemaker_source_stubs.py new file mode 100644 index 0000000000000..aebf0e524acc1 --- /dev/null +++ b/metadata-ingestion/tests/unit/test_sagemaker_source_stubs.py @@ -0,0 +1,97 @@ +import datetime + +list_feature_groups_response = { + "FeatureGroupSummaries": [ + { + "FeatureGroupName": "test-2", + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123412341234:feature-group/test-2", + "CreationTime": datetime.datetime(2021, 6, 24, 9, 48, 37, 35000), + "FeatureGroupStatus": "Created", + }, + { + "FeatureGroupName": "test-1", + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123412341234:feature-group/test-1", + "CreationTime": datetime.datetime(2021, 6, 23, 13, 58, 10, 264000), + "FeatureGroupStatus": "Created", + }, + { + "FeatureGroupName": "test", + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123412341234:feature-group/test", + "CreationTime": datetime.datetime(2021, 6, 14, 11, 3, 0, 803000), + "FeatureGroupStatus": "Created", + }, + ], + "NextToken": "", +} + +describe_feature_group_response_1 = { + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123412341234:feature-group/test-2", + "FeatureGroupName": "test-2", + "RecordIdentifierFeatureName": "some-feature-2", + "EventTimeFeatureName": "some-feature-3", + "FeatureDefinitions": [ + {"FeatureName": "some-feature-1", "FeatureType": "String"}, + {"FeatureName": "some-feature-2", "FeatureType": "Integral"}, + {"FeatureName": "some-feature-3", "FeatureType": "Fractional"}, + ], + "CreationTime": datetime.datetime(2021, 6, 24, 9, 48, 37, 35000), + "OnlineStoreConfig": {"EnableOnlineStore": True}, + "OfflineStoreConfig": { + "S3StorageConfig": { + "S3Uri": "s3://datahub-sagemaker-outputs", + "ResolvedOutputS3Uri": "s3://datahub-sagemaker-outputs/123412341234/sagemaker/us-west-2/offline-store/test-2-123412341234/data", + }, + "DisableGlueTableCreation": False, + "DataCatalogConfig": { + "TableName": "test-2-123412341234", + "Catalog": "AwsDataCatalog", + "Database": "sagemaker_featurestore", + }, + }, + "RoleArn": "arn:aws:iam::123412341234:role/service-role/AmazonSageMaker-ExecutionRole-20210614T104201", + "FeatureGroupStatus": "Created", + "Description": "Yet another test feature group", + "NextToken": "", +} + +describe_feature_group_response_2 = { + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123412341234:feature-group/test-1", + "FeatureGroupName": "test-1", + "RecordIdentifierFeatureName": "id", + "EventTimeFeatureName": "time", + "FeatureDefinitions": [ + {"FeatureName": "name", "FeatureType": "String"}, + {"FeatureName": "id", "FeatureType": "Integral"}, + {"FeatureName": "height", "FeatureType": "Fractional"}, + {"FeatureName": "time", "FeatureType": "String"}, + ], + "CreationTime": datetime.datetime(2021, 6, 23, 13, 58, 10, 264000), + "OnlineStoreConfig": {"EnableOnlineStore": True}, + "FeatureGroupStatus": "Created", + "Description": "First test feature group", + "NextToken": "", +} + +describe_feature_group_response_3 = { + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123412341234:feature-group/test", + "FeatureGroupName": "test", + "RecordIdentifierFeatureName": "feature_1", + "EventTimeFeatureName": "feature_3", + "FeatureDefinitions": [ + {"FeatureName": "feature_1", "FeatureType": "String"}, + {"FeatureName": "feature_2", "FeatureType": "Integral"}, + {"FeatureName": "feature_3", "FeatureType": "Fractional"}, + ], + "CreationTime": datetime.datetime( + 2021, + 6, + 14, + 11, + 3, + 0, + 803000, + ), + "OnlineStoreConfig": {"EnableOnlineStore": True}, + "FeatureGroupStatus": "Created", + "NextToken": "", +} diff --git a/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLFeatureTableProperties.pdl b/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLFeatureTableProperties.pdl index 7666e8296cf9b..034d1dbaf11db 100644 --- a/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLFeatureTableProperties.pdl +++ b/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLFeatureTableProperties.pdl @@ -1,6 +1,7 @@ namespace com.linkedin.ml.metadata import com.linkedin.common.Urn +import com.linkedin.common.CustomProperties /** * Properties associated with a MLFeatureTable @@ -8,7 +9,7 @@ import com.linkedin.common.Urn @Aspect = { "name": "mlFeatureTableProperties" } -record MLFeatureTableProperties { +record MLFeatureTableProperties includes CustomProperties { /** * Documentation of the MLFeatureTable